diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -48,6 +48,32 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AbsOp
+//===----------------------------------------------------------------------===//
+
+def AbsOp : Complex_Op<"abs",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "result",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "computes absolute value of a complex number";
+  let description = [{
+    The `abs` op takes a single complex number and computes its absolute value.
+
+    Example:
+
+    ```mlir
+    %a = complex.abs %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex:$complex);
+  let results = (outs AnyFloat:$result);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
@@ -80,6 +106,22 @@
   let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
 }
 
+//===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+def DivOp : ComplexArithmeticOp<"div"> {
+  let summary = "complex division";
+  let description = [{
+    The `div` operation takes two complex numbers and returns the result of
+    their division:
+
+    ```mlir
+    %a = complex.div %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ImOp
 //===----------------------------------------------------------------------===//
@@ -106,6 +148,21 @@
   let assemblyFormat = "$complex attr-dict `:` type($complex)";
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+def MulOp : ComplexArithmeticOp<"mul"> {
+  let summary = "complex multiplication";
+  let description = [{
+    The `mul` operation takes two complex numbers and returns their product:
+
+    ```mlir
+    %a = complex.mul %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -17,6 +17,33 @@
 
 namespace {
 
+/// Lowers `complex.abs` to sqrt(re * re + im * im) using LLVM dialect ops.
+struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
+  using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::AbsOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+
+    ComplexStructBuilder complexStruct(transformed.complex());
+    Value real = complexStruct.real(rewriter, loc);
+    Value imag = complexStruct.imaginary(rewriter, loc);
+
+    // |a + bi| = sqrt(a^2 + b^2). The squares are not rescaled, so this can
+    // overflow or underflow in cases where hypot() would not. Intermediates
+    // are separate statements to keep the emitted op order deterministic.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value realSq = rewriter.create<LLVM::FMulOp>(loc, real, real, fmf);
+    Value imagSq = rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf);
+    Value sqNorm = rewriter.create<LLVM::FAddOp>(loc, realSq, imagSq, fmf);
+
+    rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
+    return success();
+  }
+};
+
 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
 
@@ -123,6 +150,95 @@
   }
 };
 
+/// Lowers `complex.div` to LLVM dialect floating-point arithmetic.
+struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
+  using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2).
+    // NOTE: this is not Smith's algorithm, so c^2 + d^2 can overflow or
+    // underflow for divisors of very large or very small magnitude.
+    // Intermediates are separate statements so that the emitted op order
+    // does not depend on C++'s unspecified argument evaluation order.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value lhsRe = arg.lhs.real();
+    Value lhsIm = arg.lhs.imag();
+    Value rhsRe = arg.rhs.real();
+    Value rhsIm = arg.rhs.imag();
+
+    Value rhsReSq = rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf);
+    Value rhsImSq = rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf);
+    Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(loc, rhsReSq, rhsImSq, fmf);
+
+    Value ac = rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf);
+    Value bd = rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf);
+    Value resultReal = rewriter.create<LLVM::FAddOp>(loc, ac, bd, fmf);
+
+    Value bc = rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf);
+    Value ad = rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf);
+    Value resultImag = rewriter.create<LLVM::FSubOp>(loc, bc, ad, fmf);
+
+    result.setReal(
+        rewriter, loc,
+        rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
+    result.setImaginary(
+        rewriter, loc,
+        rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+/// Lowers `complex.mul` to LLVM dialect floating-point arithmetic.
+struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
+  using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // (a + bi) * (c + di) = (ac - bd) + (bc + ad)i. Intermediates are
+    // separate statements so that the emitted op order is deterministic.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value lhsRe = arg.lhs.real();
+    Value lhsIm = arg.lhs.imag();
+    Value rhsRe = arg.rhs.real();
+    Value rhsIm = arg.rhs.imag();
+
+    Value ac = rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf);
+    Value bd = rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf);
+    Value real = rewriter.create<LLVM::FSubOp>(loc, ac, bd, fmf);
+
+    Value bc = rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf);
+    Value ad = rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf);
+    Value imag = rewriter.create<LLVM::FAddOp>(loc, bc, ad, fmf);
+
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
 
@@ -157,7 +273,10 @@
   // clang-format off
   patterns.insert<
       AddOpConversion,
+      AbsOpConversion,
       CreateOpConversion,
+      DivOpConversion,
+      MulOpConversion,
       ImOpConversion,
       ReOpConversion,
       SubOpConversion
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -59,3 +59,77 @@
   %c = complex.sub %a, %b : complex<f32>
   return
 }
+
+func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %div = complex.div %lhs, %rhs : complex<f32>
+  return %div : complex<f32>
+}
+// CHECK-LABEL: llvm.func @complex_div(
+// CHECK-SAME: [[LHS:%.*]]: [[C_TY:!.*]], [[RHS:%.*]]: [[C_TY]]) -> [[C_TY]]
+
+// CHECK: [[LHS_RE:%.*]] = llvm.extractvalue [[LHS]][0] : [[C_TY]]
+// CHECK: [[LHS_IM:%.*]] = llvm.extractvalue [[LHS]][1] : [[C_TY]]
+// CHECK: [[RHS_RE:%.*]] = llvm.extractvalue [[RHS]][0] : [[C_TY]]
+// CHECK: [[RHS_IM:%.*]] = llvm.extractvalue [[RHS]][1] : [[C_TY]]
+
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef : [[C_TY]]
+
+// CHECK: [[RHS_RE_SQ:%.*]] = llvm.fmul [[RHS_RE]], [[RHS_RE]] : f32
+// CHECK: [[RHS_IM_SQ:%.*]] = llvm.fmul [[RHS_IM]], [[RHS_IM]] : f32
+// CHECK: [[SQ_NORM:%.*]] = llvm.fadd [[RHS_RE_SQ]], [[RHS_IM_SQ]] : f32
+
+// CHECK: [[REAL_TMP_0:%.*]] = llvm.fmul [[LHS_RE]], [[RHS_RE]] : f32
+// CHECK: [[REAL_TMP_1:%.*]] = llvm.fmul [[LHS_IM]], [[RHS_IM]] : f32
+// CHECK: [[REAL_TMP_2:%.*]] = llvm.fadd [[REAL_TMP_0]], [[REAL_TMP_1]] : f32
+
+// CHECK: [[IMAG_TMP_0:%.*]] = llvm.fmul [[LHS_IM]], [[RHS_RE]] : f32
+// CHECK: [[IMAG_TMP_1:%.*]] = llvm.fmul [[LHS_RE]], [[RHS_IM]] : f32
+// CHECK: [[IMAG_TMP_2:%.*]] = llvm.fsub [[IMAG_TMP_0]], [[IMAG_TMP_1]] : f32
+
+// CHECK: [[REAL:%.*]] = llvm.fdiv [[REAL_TMP_2]], [[SQ_NORM]] : f32
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[REAL]], [[RESULT_0]][0] : [[C_TY]]
+// CHECK: [[IMAG:%.*]] = llvm.fdiv [[IMAG_TMP_2]], [[SQ_NORM]] : f32
+// CHECK: [[RESULT_2:%.*]] = llvm.insertvalue [[IMAG]], [[RESULT_1]][1] : [[C_TY]]
+// CHECK: llvm.return [[RESULT_2]] : [[C_TY]]
+
+
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs : complex<f32>
+  return %mul : complex<f32>
+}
+// CHECK-LABEL: llvm.func @complex_mul(
+// CHECK-SAME: [[LHS:%.*]]: [[C_TY:!.*]], [[RHS:%.*]]: [[C_TY]]) -> [[C_TY]]
+
+// CHECK: [[LHS_RE:%.*]] = llvm.extractvalue [[LHS]][0] : [[C_TY]]
+// CHECK: [[LHS_IM:%.*]] = llvm.extractvalue [[LHS]][1] : [[C_TY]]
+// CHECK: [[RHS_RE:%.*]] = llvm.extractvalue [[RHS]][0] : [[C_TY]]
+// CHECK: [[RHS_IM:%.*]] = llvm.extractvalue [[RHS]][1] : [[C_TY]]
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef : [[C_TY]]
+
+// CHECK: [[REAL_TMP_0:%.*]] = llvm.fmul [[LHS_RE]], [[RHS_RE]] : f32
+// CHECK: [[REAL_TMP_1:%.*]] = llvm.fmul [[LHS_IM]], [[RHS_IM]] : f32
+// CHECK: [[REAL:%.*]] = llvm.fsub [[REAL_TMP_0]], [[REAL_TMP_1]] : f32
+
+// CHECK: [[IMAG_TMP_0:%.*]] = llvm.fmul [[LHS_IM]], [[RHS_RE]] : f32
+// CHECK: [[IMAG_TMP_1:%.*]] = llvm.fmul [[LHS_RE]], [[RHS_IM]] : f32
+// CHECK: [[IMAG:%.*]] = llvm.fadd [[IMAG_TMP_0]], [[IMAG_TMP_1]] : f32
+
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[REAL]], [[RESULT_0]][0]
+// CHECK: [[RESULT_2:%.*]] = llvm.insertvalue [[IMAG]], [[RESULT_1]][1]
+// CHECK: llvm.return [[RESULT_2]] : [[C_TY]]
+
+func @complex_abs(%arg: complex<f32>) -> f32 {
+  %abs = complex.abs %arg : complex<f32>
+  return %abs : f32
+}
+
+// CHECK-LABEL: llvm.func @complex_abs(
+// CHECK-SAME: [[ARG:%.*]]: [[C_TY:!.*]])
+// CHECK: [[REAL:%.*]] = llvm.extractvalue [[ARG]][0] : !llvm.struct<(f32, f32)>
+// CHECK: [[IMAG:%.*]] = llvm.extractvalue [[ARG]][1] : !llvm.struct<(f32, f32)>
+// CHECK: [[REAL_SQ:%.*]] = llvm.fmul [[REAL]], [[REAL]] : f32
+// CHECK: [[IMAG_SQ:%.*]] = llvm.fmul [[IMAG]], [[IMAG]] : f32
+// CHECK: [[SQ_NORM:%.*]] = llvm.fadd [[REAL_SQ]], [[IMAG_SQ]] : f32
+// CHECK: [[NORM:%.*]] = "llvm.intr.sqrt"([[SQ_NORM]]) : (f32) -> f32
+// CHECK: llvm.return [[NORM]] : f32
+
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -14,9 +14,18 @@
   // CHECK: complex.im [[C]] : complex<f32>
   %imag = complex.im %complex : complex<f32>
 
+  // CHECK: complex.abs [[C]] : complex<f32>
+  %abs = complex.abs %complex : complex<f32>
+
   // CHECK: complex.add [[C]], [[C]] : complex<f32>
   %sum = complex.add %complex, %complex : complex<f32>
 
+  // CHECK: complex.div [[C]], [[C]] : complex<f32>
+  %div = complex.div %complex, %complex : complex<f32>
+
+  // CHECK: complex.mul [[C]], [[C]] : complex<f32>
+  %prod = complex.mul %complex, %complex : complex<f32>
+
   // CHECK: complex.sub [[C]], [[C]] : complex<f32>
   %diff = complex.sub %complex, %complex : complex<f32>
   return