diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -48,6 +48,32 @@ }]; } +//===----------------------------------------------------------------------===// +// AbsOp +//===----------------------------------------------------------------------===// + +def AbsOp : Complex_Op<"abs", + [NoSideEffect, + TypesMatchWith<"complex element type matches result type", + "complex", "result", + "$_self.cast().getElementType()">]> { + let summary = "computes absolute value of a complex number"; + let description = [{ + The `abs` op takes a single complex number and computes its absolute value. + + Example: + + ```mlir + %a = complex.abs %b : complex + ``` + }]; + + let arguments = (ins Complex:$complex); + let results = (outs AnyFloat:$result); + + let assemblyFormat = "$complex attr-dict `:` type($complex)"; +} + //===----------------------------------------------------------------------===// // CreateOp //===----------------------------------------------------------------------===// @@ -80,6 +106,22 @@ let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)"; } +//===----------------------------------------------------------------------===// +// DivOp +//===----------------------------------------------------------------------===// + +def DivOp : ComplexArithmeticOp<"div"> { + let summary = "complex division"; + let description = [{ + The `div` operation takes two complex numbers and returns the result of their + division: + + ```mlir + %a = complex.div %b, %c : complex + ``` + }]; +} + //===----------------------------------------------------------------------===// // ImOp //===----------------------------------------------------------------------===// @@ -106,6 +148,21 @@ let assemblyFormat = "$complex attr-dict `:` type($complex)"; } 
+//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : ComplexArithmeticOp<"mul"> { + let summary = "complex multiplication"; + let description = [{ + The `mul` operation takes two complex numbers and returns their product: + + ```mlir + %a = complex.mul %b, %c : complex + ``` + }]; +} + //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -17,6 +17,29 @@ namespace { +struct AbsOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(complex::AbsOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + complex::AbsOp::Adaptor transformed(operands); + auto loc = op.getLoc(); + + ComplexStructBuilder complexStruct(transformed.complex()); + Value real = complexStruct.real(rewriter, loc); + Value imag = complexStruct.imaginary(rewriter, loc); + + auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + Value sqNorm = rewriter.create( + loc, rewriter.create(loc, real, real, fmf), + rewriter.create(loc, imag, imag, fmf), fmf); + + rewriter.replaceOpWithNewOp(op, sqNorm); + return success(); + } +}; + struct CreateOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -123,6 +146,88 @@ } }; +struct DivOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(complex::DivOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + 
auto loc = op.getLoc(); + BinaryComplexOperands arg = + unpackBinaryComplexOperands(op, operands, rewriter); + + // Initialize complex number struct for result. + auto structType = typeConverter->convertType(op.getType()); + auto result = ComplexStructBuilder::undef(rewriter, loc, structType); + + // Emit IR to divide complex numbers. + auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + Value rhsRe = arg.rhs.real(); + Value rhsIm = arg.rhs.imag(); + Value lhsRe = arg.lhs.real(); + Value lhsIm = arg.lhs.imag(); + + Value rhsSqNorm = rewriter.create( + loc, rewriter.create(loc, rhsRe, rhsRe, fmf), + rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); + + Value resultReal = rewriter.create( + loc, rewriter.create(loc, lhsRe, rhsRe, fmf), + rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); + + Value resultImag = rewriter.create( + loc, rewriter.create(loc, lhsIm, rhsRe, fmf), + rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + + result.setReal( + rewriter, loc, + rewriter.create(loc, resultReal, rhsSqNorm, fmf)); + result.setImaginary( + rewriter, loc, + rewriter.create(loc, resultImag, rhsSqNorm, fmf)); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +struct MulOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(complex::MulOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + BinaryComplexOperands arg = + unpackBinaryComplexOperands(op, operands, rewriter); + + // Initialize complex number struct for result. + auto structType = typeConverter->convertType(op.getType()); + auto result = ComplexStructBuilder::undef(rewriter, loc, structType); + + // Emit IR to multiply complex numbers. 
+ auto fmf = LLVM::FMFAttr::get({}, op.getContext()); + Value rhsRe = arg.rhs.real(); + Value rhsIm = arg.rhs.imag(); + Value lhsRe = arg.lhs.real(); + Value lhsIm = arg.lhs.imag(); + + Value real = rewriter.create( + loc, rewriter.create(loc, rhsRe, lhsRe, fmf), + rewriter.create(loc, rhsIm, lhsIm, fmf), fmf); + + Value imag = rewriter.create( + loc, rewriter.create(loc, lhsIm, rhsRe, fmf), + rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); + + result.setReal(rewriter, loc, real); + result.setImaginary(rewriter, loc, imag); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + struct SubOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -156,9 +261,12 @@ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert< + AbsOpConversion, AddOpConversion, CreateOpConversion, + DivOpConversion, ImOpConversion, + MulOpConversion, ReOpConversion, SubOpConversion >(converter); diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -18,6 +18,8 @@ return } +// ----- + // CHECK-LABEL: llvm.func @complex_addition() // CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> // CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> @@ -39,6 +41,8 @@ return } +// ----- + // CHECK-LABEL: llvm.func @complex_substraction() // CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> // CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> @@ -59,3 +63,79 @@ %c = complex.sub %a, %b : complex return } + +// ----- + +// CHECK-LABEL: llvm.func @complex_div +// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]] +func @complex_div(%lhs: 
complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs : complex + return %div : complex +} +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]] + +// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] + +// CHECK: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] : f32 +// CHECK: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] : f32 + +// CHECK: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] : f32 +// CHECK: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] : f32 +// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32 + +// CHECK: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32 +// CHECK: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32 +// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32 + +// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] : f32 +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32 +// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] +// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] + +// ----- + +// CHECK-LABEL: llvm.func @complex_mul +// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]] +func @complex_mul(%lhs: complex, %rhs: complex) -> complex { + %mul = complex.mul %lhs, %rhs : complex + return %mul : complex +} +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = 
llvm.extractvalue %[[RHS]][1] : ![[C_TY]] +// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] + +// CHECK: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] : f32 +// CHECK: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] : f32 +// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32 + +// CHECK: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32 +// CHECK: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32 +// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32 + +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] +// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] +// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] + +// ----- + +// CHECK-LABEL: llvm.func @complex_abs +// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]]) +func @complex_abs(%arg: complex) -> f32 { + %abs = complex.abs %arg: complex + return %abs : f32 +} +// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] +// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 +// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32 +// CHECK: llvm.return %[[NORM]] : f32 + diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir --- a/mlir/test/Dialect/Complex/ops.mlir +++ b/mlir/test/Dialect/Complex/ops.mlir @@ -3,21 +3,30 @@ // CHECK-LABEL: func @ops( -// CHECK-SAME: [[F:%.*]]: f32) { +// CHECK-SAME: %[[F:.*]]: f32) { func @ops(%f: f32) { - // CHECK: [[C:%.*]] = complex.create [[F]], [[F]] : complex + // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex %complex = complex.create %f, %f : complex - // CHECK: complex.re [[C]] : complex + // CHECK: complex.re %[[C]] : complex %real = complex.re %complex : complex - // CHECK: 
complex.im [[C]] : complex + // CHECK: complex.im %[[C]] : complex %imag = complex.im %complex : complex - // CHECK: complex.add [[C]], [[C]] : complex + // CHECK: complex.abs %[[C]] : complex + %abs = complex.abs %complex : complex + + // CHECK: complex.add %[[C]], %[[C]] : complex %sum = complex.add %complex, %complex : complex - // CHECK: complex.sub [[C]], [[C]] : complex + // CHECK: complex.div %[[C]], %[[C]] : complex + %div = complex.div %complex, %complex : complex + + // CHECK: complex.mul %[[C]], %[[C]] : complex + %prod = complex.mul %complex, %complex : complex + + // CHECK: complex.sub %[[C]], %[[C]] : complex %diff = complex.sub %complex, %complex : complex return }