diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -109,6 +109,7 @@
 // integer tensor.  The custom assembly form of the operation is as follows
 //
 //     <op>i %0, %1 : i32
+//
 class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     ArithmeticOp<mnemonic, traits>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -121,10 +122,23 @@
 // is as follows
 //
 //     <op>f %0, %1 : f32
+//
 class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     ArithmeticOp<mnemonic, traits>,
     Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
 
+// Base class for standard arithmetic operations on complex numbers with a
+// floating-point element type.
+// These operations take two operands and return one result, all of which must
+// be complex numbers of the same type.
+// The custom assembly form of the operation is as follows
+//
+//     <op>cf %0, %1 : complex<f32>
+//
+class ComplexFloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits>,
+    Arguments<(ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs)>;
+
 // Base class for memref allocating ops: alloca and alloc.
 //
 //   %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -201,6 +215,26 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AddCFOp
+//===----------------------------------------------------------------------===//
+
+def AddCFOp : ComplexFloatArithmeticOp<"addcf"> {
+  let summary = "complex number addition";
+  let description = [{
+    The `addcf` operation takes two complex number operands and returns their
+    sum, a single complex number.
+    All operands and result must be of the same type, a complex number with a
+    floating-point element type.
+
+    Example:
+
+    ```mlir
+    %a = addcf %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -2407,6 +2441,26 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SubCFOp
+//===----------------------------------------------------------------------===//
+
+def SubCFOp : ComplexFloatArithmeticOp<"subcf"> {
+  let summary = "complex number subtraction";
+  let description = [{
+    The `subcf` operation takes two complex number operands and returns their
+    difference, a single complex number.
+    All operands and result must be of the same type, a complex number with a
+    floating-point element type.
+
+    Example:
+
+    ```mlir
+    %a = subcf %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -719,7 +719,6 @@
         SignlessIntegerLike.predicate, FloatLike.predicate]>,
     "signless-integer-like or floating-point-like">;
 
-
 //===----------------------------------------------------------------------===//
 // Attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -443,12 +443,12 @@
   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
 }
 
-void ComplexStructBuilder ::setImaginary(OpBuilder &builder, Location loc,
-                                         Value imaginary) {
+void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
+                                        Value imaginary) {
   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
 }
 
-Value ComplexStructBuilder ::imaginary(OpBuilder &builder, Location loc) {
+Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
 }
 
@@ -1326,8 +1326,7 @@
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
 
-// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and
-// `ImOp`.
+// Lowerings for operations on complex numbers.
 
 struct CreateComplexOpLowering
     : public ConvertOpToLLVMPattern<CreateComplexOp> {
@@ -1385,6 +1384,67 @@
   }
 };
 
+// Real and imaginary parts of both operands of a binary complex operation,
+// extracted from their lowered LLVM struct representations.
+struct BinaryComplexOperands {
+  Value lhsReal, lhsImag, rhsReal, rhsImag;
+};
+
+// Extracts the real and imaginary parts of the already-converted `lhs` and
+// `rhs` operands (`operands`) of the binary complex operation `op`.
+template <typename OpTy>
+BinaryComplexOperands
+unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
+                            ConversionPatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  OperandAdaptor<OpTy> transformed(operands);
+
+  // Extract real and imaginary values from operands.
+  BinaryComplexOperands unpacked;
+  ComplexStructBuilder lhs(transformed.lhs());
+  unpacked.lhsReal = lhs.real(rewriter, loc);
+  unpacked.lhsImag = lhs.imaginary(rewriter, loc);
+  ComplexStructBuilder rhs(transformed.rhs());
+  unpacked.rhsReal = rhs.real(rewriter, loc);
+  unpacked.rhsImag = rhs.imaginary(rewriter, loc);
+
+  return unpacked;
+}
+
+// Lowers a binary complex operation that acts independently on the real and
+// imaginary parts of its operands (e.g. addition, subtraction) by applying the
+// LLVM floating-point operation `TargetOp` to each part separately.
+template <typename SourceOp, typename TargetOp>
+struct ComponentwiseComplexOpLowering
+    : public ConvertOpToLLVMPattern<SourceOp> {
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto op = cast<SourceOp>(operation);
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = this->typeConverter.convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Apply `TargetOp` to the real and the imaginary parts separately.
+    Value real = rewriter.create<TargetOp>(loc, arg.lhsReal, arg.rhsReal);
+    Value imag = rewriter.create<TargetOp>(loc, arg.lhsImag, arg.rhsImag);
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+using AddCFOpLowering = ComponentwiseComplexOpLowering<AddCFOp, LLVM::FAddOp>;
+using SubCFOpLowering = ComponentwiseComplexOpLowering<SubCFOp, LLVM::FSubOp>;
+
 // Check if the MemRefType `type` is supported by the lowering. We currently
 // only support memrefs with identity maps.
 static bool isSupportedMemRefType(MemRefType type) {
@@ -2874,6 +2934,7 @@
   // clang-format off
   patterns.insert<
       AbsFOpLowering,
+      AddCFOpLowering,
       AddFOpLowering,
       AddIOpLowering,
       AllocaOpLowering,
@@ -2921,6 +2982,7 @@
       SplatOpLowering,
       SplatNdOpLowering,
       SqrtOpLowering,
+      SubCFOpLowering,
       SubFOpLowering,
       SubIOpLowering,
       TruncateIOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -83,6 +83,48 @@
   return
 }
 
+// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : !llvm.double
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
+func @complex_addition() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = create_complex %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = create_complex %b_re, %b_im : complex<f64>
+  %c = addcf %a, %b : complex<f64>
+  return
+}
+
+// CHECK-LABEL: llvm.func @complex_subtraction()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : !llvm.double
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
+func @complex_subtraction() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = create_complex %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = create_complex %b_re, %b_im : complex<f64>
+  %c = subcf %a, %b : complex<f64>
+  return
+}
+
 // CHECK-LABEL: func @simple_caller() {
 // CHECK-NEXT:   llvm.call @simple_loop() : () -> ()
 // CHECK-NEXT:   llvm.return