diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -109,6 +109,7 @@
 // integer tensor.  The custom assembly form of the operation is as follows
 //
 //     <op>i %0, %1 : i32
+//
 class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     ArithmeticOp<mnemonic, traits>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -121,10 +122,22 @@
 // is as follows
 //
 //     <op>f %0, %1 : f32
+//
 class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     ArithmeticOp<mnemonic, traits>,
     Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
 
+// Base class for standard arithmetic operations on complex numbers.
+// These operations take two operands and return one result, all of which must
+// be complex numbers of the same type.
+// The assembly format is as follows
+//
+//     <op>c %0, %1 : complex<f32>
+//
+class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits>,
+    Arguments<(ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs)>;
+
 // Base class for memref allocating ops: alloca and alloc.
 //
 //   %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -201,6 +214,25 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AddCOp
+//===----------------------------------------------------------------------===//
+
+def AddCOp : ComplexArithmeticOp<"addc"> {
+  let summary = "complex number addition";
+  let description = [{
+    The `addc` operation takes two complex number operands and returns their
+    sum, a single complex number.
+    All operands and the result must be of the same type.
+
+    Example:
+
+    ```mlir
+    %a = addc %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -2407,6 +2439,25 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SubCOp
+//===----------------------------------------------------------------------===//
+
+def SubCOp : ComplexArithmeticOp<"subc"> {
+  let summary = "complex number subtraction";
+  let description = [{
+    The `subc` operation takes two complex number operands and returns their
+    difference, a single complex number.
+    All operands and the result must be of the same type.
+
+    Example:
+
+    ```mlir
+    %a = subc %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1326,8 +1326,7 @@
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
 
-// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and
-// `ImOp`.
+// Lowerings for operations on complex numbers.
 
 struct CreateComplexOpLowering
     : public ConvertOpToLLVMPattern<CreateComplexOp> {
@@ -1385,6 +1384,82 @@
   }
 };
 
+// Base pattern for lowering binary operations on complex numbers. It extracts
+// the real and imaginary parts of both (converted) operands, creates an undef
+// result struct, and delegates the actual computation to
+// `emitOpSpecificLLVMIR`, which is expected to set both parts of `result`.
+template <typename BinaryOp>
+struct ConvertBinaryOpOnComplexNumbersToLLVMPattern
+    : public ConvertOpToLLVMPattern<BinaryOp> {
+  using ConvertOpToLLVMPattern<BinaryOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto bop = cast<BinaryOp>(op);
+    auto loc = bop.getLoc();
+    OperandAdaptor<BinaryOp> transformed(operands);
+
+    // Extract real and imaginary values from operands.
+    ComplexStructBuilder lhs(transformed.lhs());
+    Value lhsReal = lhs.real(rewriter, loc);
+    Value lhsImag = lhs.imaginary(rewriter, loc);
+    ComplexStructBuilder rhs(transformed.rhs());
+    Value rhsReal = rhs.real(rewriter, loc);
+    Value rhsImag = rhs.imaginary(rewriter, loc);
+
+    // Initialize complex number struct for result.
+    auto structType = this->typeConverter.convertType(bop.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit operation-specific LLVM IR.
+    this->emitOpSpecificLLVMIR(rewriter, loc, lhsReal, lhsImag, rhsReal,
+                               rhsImag, result);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+
+  virtual void emitOpSpecificLLVMIR(ConversionPatternRewriter &rewriter,
+                                    Location loc, Value lhsReal, Value lhsImag,
+                                    Value rhsReal, Value rhsImag,
+                                    ComplexStructBuilder &result) const = 0;
+};
+
+// Lowers `addc` to component-wise `llvm.fadd` on the real and imaginary parts.
+struct AddCOpLowering
+    : public ConvertBinaryOpOnComplexNumbersToLLVMPattern<AddCOp> {
+  using ConvertBinaryOpOnComplexNumbersToLLVMPattern<
+      AddCOp>::ConvertBinaryOpOnComplexNumbersToLLVMPattern;
+
+  void emitOpSpecificLLVMIR(ConversionPatternRewriter &rewriter, Location loc,
+                            Value lhsReal, Value lhsImag, Value rhsReal,
+                            Value rhsImag,
+                            ComplexStructBuilder &result) const override {
+    Value real = rewriter.create<LLVM::FAddOp>(loc, lhsReal, rhsReal);
+    Value imag = rewriter.create<LLVM::FAddOp>(loc, lhsImag, rhsImag);
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+  }
+};
+
+// Lowers `subc` to component-wise `llvm.fsub` on the real and imaginary parts.
+struct SubCOpLowering
+    : public ConvertBinaryOpOnComplexNumbersToLLVMPattern<SubCOp> {
+  using ConvertBinaryOpOnComplexNumbersToLLVMPattern<
+      SubCOp>::ConvertBinaryOpOnComplexNumbersToLLVMPattern;
+
+  void emitOpSpecificLLVMIR(ConversionPatternRewriter &rewriter, Location loc,
+                            Value lhsReal, Value lhsImag, Value rhsReal,
+                            Value rhsImag,
+                            ComplexStructBuilder &result) const override {
+    Value real = rewriter.create<LLVM::FSubOp>(loc, lhsReal, rhsReal);
+    Value imag = rewriter.create<LLVM::FSubOp>(loc, lhsImag, rhsImag);
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+  }
+};
+
 // Check if the MemRefType `type` is supported by the lowering. We currently
 // only support memrefs with identity maps.
 static bool isSupportedMemRefType(MemRefType type) {
@@ -2981,6 +3056,7 @@
   // clang-format off
   patterns.insert<
       AbsFOpLowering,
+      AddCOpLowering,
       AddFOpLowering,
       AddIOpLowering,
       AllocaOpLowering,
@@ -3029,6 +3105,7 @@
       SplatOpLowering,
       SplatNdOpLowering,
       SqrtOpLowering,
+      SubCOpLowering,
       SubFOpLowering,
       SubIOpLowering,
       TruncateIOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -83,6 +83,48 @@
   return
 }
 
+// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : !llvm.double
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
+func @complex_addition() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = create_complex %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = create_complex %b_re, %b_im : complex<f64>
+  %c = addc %a, %b : complex<f64>
+  return
+}
+
+// CHECK-LABEL: llvm.func @complex_subtraction()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : !llvm.double
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
+func @complex_subtraction() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = create_complex %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = create_complex %b_re, %b_im : complex<f64>
+  %c = subc %a, %b : complex<f64>
+  return
+}
+
 // CHECK-LABEL: func @simple_caller() {
 // CHECK-NEXT:  llvm.call @simple_loop() : () -> ()
 // CHECK-NEXT:  llvm.return