diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -26,6 +26,7 @@ namespace mlir { +class ComplexType; class LLVMTypeConverter; class UnrankedMemRefType; @@ -139,24 +140,29 @@ LLVM::LLVMDialect *llvmDialect; private: - // Convert a function type. The arguments and results are converted one by - // one. Additionally, if the function returns more than one value, pack the - // results into an LLVM IR structure type so that the converted function type - // returns at most one result. + /// Convert a function type. The arguments and results are converted one by + /// one. Additionally, if the function returns more than one value, pack the + /// results into an LLVM IR structure type so that the converted function type + /// returns at most one result. Type convertFunctionType(FunctionType type); - // Convert the index type. Uses llvmModule data layout to create an integer - // of the pointer bitwidth. + /// Convert the index type. Uses llvmModule data layout to create an integer + /// of the pointer bitwidth. Type convertIndexType(IndexType type); - // Convert an integer type `i*` to `!llvm<"i*">`. + /// Convert an integer type `i*` to `!llvm<"i*">`. Type convertIntegerType(IntegerType type); - // Convert a floating point type: `f16` to `!llvm.half`, `f32` to - // `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported - // by LLVM. + /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to + /// `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported + /// by LLVM. Type convertFloatType(FloatType type); + /// Convert a complex number type: `complex<f16>` to `!llvm<"{ half, half }">`, + /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to + /// `!llvm<"{ double, double }">`.
`complex<bf16>` is not supported. + Type convertComplexType(ComplexType type); + /// Convert a memref type into an LLVM type that captures the relevant data. Type convertMemRefType(MemRefType type); @@ -221,6 +227,25 @@ void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); }; +class ComplexStructBuilder : public StructBuilder { +public: + /// Construct a helper for the given complex number value. + using StructBuilder::StructBuilder; + /// Build IR creating an `undef` value of the complex number type. + static ComplexStructBuilder undef(OpBuilder &builder, Location loc, + Type type); + + /// Build IR extracting the real part from the complex number struct. + Value real(OpBuilder &builder, Location loc); + /// Build IR inserting the real part into the complex number struct. + void setReal(OpBuilder &builder, Location loc, Value real); + + /// Build IR extracting the imaginary part from the complex number struct. + Value imaginary(OpBuilder &builder, Location loc); + /// Build IR inserting the imaginary part into the complex number struct. + void setImaginary(OpBuilder &builder, Location loc, Value imaginary); +}; + /// Helper class to produce LLVM dialect operations extracting or inserting /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. /// The Value may be null, in which case none of the operations are valid. @@ -476,8 +501,8 @@ } }; -/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops -/// with one result. This supports higher-dimensional vector types. +/// Basic lowering implementation to rewrite Ops with just one result to the +/// LLVM Dialect. This supports higher-dimensional vector types.
template class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -19,6 +19,7 @@ using std_alloc = ValueBuilder; using std_alloca = ValueBuilder; using std_call = OperationBuilder; +using std_create_complex = ValueBuilder; using std_constant = ValueBuilder; using std_constant_float = ValueBuilder; using std_constant_index = ValueBuilder; @@ -26,10 +27,12 @@ using std_dealloc = OperationBuilder; using std_dim = ValueBuilder; using std_extract_element = ValueBuilder; +using std_im = ValueBuilder; using std_index_cast = ValueBuilder; using std_muli = ValueBuilder; using std_mulf = ValueBuilder; using std_memref_cast = ValueBuilder; +using std_re = ValueBuilder; using std_ret = OperationBuilder; using std_select = ValueBuilder; using std_load = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -982,6 +982,40 @@ } //===----------------------------------------------------------------------===// +// CreateComplexOp +//===----------------------------------------------------------------------===// + +def CreateComplexOp : Std_Op<"create_complex", + [NoSideEffect, + AllTypesMatch<["real", "imaginary"]>, + TypesMatchWith<"complex element type matches real operand type", + "complex", "real", + "$_self.cast().getElementType()">, + TypesMatchWith<"complex element type matches imaginary operand type", + "complex", "imaginary", + "$_self.cast().getElementType()">]> { + let summary = "creates a complex number"; + let description = [{ + The `create_complex` operation creates a complex number from two + floating-point operands, the
real and the imaginary part. Both operands must have the element type of the resulting complex number. + + Example: + + ```mlir + %a = create_complex %b, %c : complex<f32> + ``` + }]; + + let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary); + let results = (outs Complex:$complex); + + let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)"; + + // `CreateComplexOp` is fully verified by its traits. + let verifier = ?; +} + +//===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -1498,6 +1532,36 @@ } //===----------------------------------------------------------------------===// +// ImOp +//===----------------------------------------------------------------------===// + +def ImOp : Std_Op<"im", + [NoSideEffect, + TypesMatchWith<"complex element type matches result type", + "complex", "imaginary", + "$_self.cast().getElementType()">]> { + let summary = "extracts the imaginary part of a complex number"; + let description = [{ + The `im` operation takes a single complex number as its operand and extracts + the imaginary part as a floating-point value of the complex number's element type. + + Example: + + ```mlir + %a = im %b : complex<f32> + ``` + }]; + + let arguments = (ins Complex:$complex); + let results = (outs AnyFloat:$imaginary); + + let assemblyFormat = "$complex attr-dict `:` type($complex)"; + + // `ImOp` is fully verified by its traits.
+ let verifier = ?; +} + +//===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -1878,6 +1942,36 @@ } //===----------------------------------------------------------------------===// +// ReOp +//===----------------------------------------------------------------------===// + +def ReOp : Std_Op<"re", + [NoSideEffect, + TypesMatchWith<"complex element type matches result type", + "complex", "real", + "$_self.cast().getElementType()">]> { + let summary = "extracts the real part of a complex number"; + let description = [{ + The `re` operation takes a single complex number as its operand and extracts + the real part as a floating-point value of the complex number's element type. + + Example: + + ```mlir + %a = re %b : complex<f32> + ``` + }]; + + let arguments = (ins Complex:$complex); + let results = (outs AnyFloat:$real); + + let assemblyFormat = "$complex attr-dict `:` type($complex)"; + + // `ReOp` is fully verified by its traits. + let verifier = ?; +} + +//===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -138,6 +138,7 @@ module->getDataLayout().getPointerSizeInBits(); // Register conversions for the standard types. + addConversion([&](ComplexType type) { return convertComplexType(type); }); addConversion([&](FloatType type) { return convertFloatType(type); }); addConversion([&](FunctionType type) { return convertFunctionType(type); }); addConversion([&](IndexType type) { return convertIndexType(type); }); @@ -191,6 +192,20 @@ } } +// Convert a `ComplexType` to an LLVM type.
The result is a complex number +// struct with two entries of the converted element type: the real part at +// position 0 and the imaginary part at position 1. Returns a null type if the +// element type cannot be converted (e.g. `bf16`, which LLVM does not support). +static constexpr unsigned kRealPosInComplexNumberStruct = 0; +static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; +Type LLVMTypeConverter::convertComplexType(ComplexType type) { + auto elementType = convertType(type.getElementType()) + .dyn_cast_or_null<LLVM::LLVMType>(); + if (!elementType) + return {}; + return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType}); +} + // Except for signatures, MLIR function types are converted into LLVM // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { @@ -392,6 +407,7 @@ /*============================================================================*/ /* StructBuilder implementation */ /*============================================================================*/ + StructBuilder::StructBuilder(Value v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value.getType().dyn_cast(); @@ -410,6 +426,35 @@ value = builder.create(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } + +/*============================================================================*/ +/* ComplexStructBuilder implementation */ +/*============================================================================*/ + +ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, + Location loc, Type type) { + Value val = builder.create(loc, type.cast()); + return ComplexStructBuilder(val); +} + +void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, + Value real) { + setPtr(builder, loc, kRealPosInComplexNumberStruct, real); +} + +Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kRealPosInComplexNumberStruct); +} + +void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, + Value imaginary) { + setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); +} + +Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); +} + /*============================================================================*/ /* MemRefDescriptor implementation */ /*============================================================================*/ @@ -1284,6 +1329,68 @@ OneToOneConvertToLLVMPattern; using XOrOpLowering = VectorConvertToLLVMPattern; +// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and +// `ImOp`. + +struct CreateComplexOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto complexOp = cast(op); + OperandAdaptor transformed(operands); + + // Pack real and imaginary part in a complex number struct. + auto loc = op->getLoc(); + auto structType = typeConverter.convertType(complexOp.getType()); + // Bail out if the complex type has no LLVM equivalent (e.g. `complex<bf16>`). + if (!structType) + return failure(); + auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); + complexStruct.setReal(rewriter, loc, transformed.real()); + complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); + + rewriter.replaceOp(op, {complexStruct}); + return success(); + } +}; + +struct ReOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + OperandAdaptor transformed(operands); + + // Extract real part from the complex number struct.
+ ComplexStructBuilder complexStruct(transformed.complex()); + Value real = complexStruct.real(rewriter, op->getLoc()); + rewriter.replaceOp(op, real); + + return success(); + } +}; + +struct ImOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + OperandAdaptor transformed(operands); + + // Extract imaginary part from the complex number struct. + ComplexStructBuilder complexStruct(transformed.complex()); + Value imaginary = complexStruct.imaginary(rewriter, op->getLoc()); + rewriter.replaceOp(op, imaginary); + + return success(); + } +}; + // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. static bool isSupportedMemRefType(MemRefType type) { @@ -2896,6 +3003,7 @@ CopySignOpLowering, CosOpLowering, ConstLLVMOpLowering, + CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, @@ -2906,12 +3014,14 @@ Log2OpLowering, FPExtLowering, FPTruncLowering, + ImOpLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrOpLowering, PrefetchOpLowering, + ReOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -65,6 +65,24 @@ return } +// CHECK-LABEL: llvm.func @complex_numbers() +// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : !llvm.float +// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : !llvm.float +// CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm<"{ float, float }">
+// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2]][0] : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2]][1] : !llvm<"{ float, float }"> +// CHECK-NEXT: llvm.return +func @complex_numbers() { + %real0 = constant 1.2 : f32 + %imag0 = constant 3.4 : f32 + %cplx2 = create_complex %real0, %imag0 : complex + %real1 = re %cplx2 : complex + %imag1 = im %cplx2 : complex + return +} + // CHECK-LABEL: func @simple_caller() { // CHECK-NEXT: llvm.call @simple_loop() : () -> () // CHECK-NEXT: llvm.return @@ -367,6 +385,12 @@ func @get_i64() -> (i64) // CHECK-LABEL: func @get_f32() -> !llvm.float func @get_f32() -> (f32) +// CHECK-LABEL: func @get_c16() -> !llvm<"{ half, half }"> +func @get_c16() -> (complex) +// CHECK-LABEL: func @get_c32() -> !llvm<"{ float, float }"> +func @get_c32() -> (complex) +// CHECK-LABEL: func @get_c64() -> !llvm<"{ double, double }"> +func @get_c64() -> (complex) // CHECK-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> // CHECK32-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i32, [4 x i32], [4 x i32] }"> func @get_memref() -> (memref<42x?x10x?xf32>) diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -86,6 +86,24 @@ // CHECK: %13 = muli %4, %4 : i32 %i6 = muli %i2, %i2 : i32 + // CHECK: %[[C0:.*]] = create_complex %[[F2:.*]], %[[F2]] : complex + %c0 = "std.create_complex"(%f2, %f2) : (f32, f32) -> complex + + // CHECK: %[[C1:.*]] = create_complex %[[F2]], %[[F2]] : complex + %c1 = create_complex %f2, %f2 : complex + + // CHECK: %[[REAL0:.*]] = re %[[C0]] : complex + %real0 = "std.re"(%c0) : (complex) -> f32 + + // CHECK: %[[REAL1:.*]] = re %[[C0]] : complex + %real1 = re %c0 : complex + + // CHECK: %[[IMAG0:.*]] = im %[[C0]] : complex + %imag0 =
"std.im"(%c0) : (complex) -> f32 + + // CHECK: %[[IMAG1:.*]] = im %[[C0]] : complex + %imag1 = im %c0 : complex + // CHECK: %c42_i32 = constant 42 : i32 %x = "std.constant"(){value = 42 : i32} : () -> i32 diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1220,3 +1220,47 @@ // expected-error@-1 {{requires an ancestor op with AutomaticAllocationScope trait}} return }) : () -> () + +// ----- + +func @complex_number_from_non_float_operands(%real: i32, %imag: i32) { + // expected-error@+1 {{'complex' must be complex type with floating-point elements, but got 'complex'}} + std.create_complex %real, %imag : complex + return +} + +// ----- + +// expected-note@+1 {{prior use here}} +func @complex_number_from_different_float_types(%real: f32, %imag: f64) { + // expected-error@+1 {{expects different type than prior uses: 'f32' vs 'f64'}} + std.create_complex %real, %imag : complex + return +} + +// ----- + +// expected-note@+1 {{prior use here}} +func @complex_number_from_incompatible_float_type(%real: f32, %imag: f32) { + // expected-error@+1 {{expects different type than prior uses: 'f64' vs 'f32'}} + std.create_complex %real, %imag : complex + return +} + +// ----- + +// expected-note@+1 {{prior use here}} +func @real_part_from_incompatible_complex_type(%cplx: complex) { + // expected-error@+1 {{expects different type than prior uses: 'complex' vs 'complex'}} + std.re %cplx : complex + return +} + +// ----- + +// expected-note@+1 {{prior use here}} +func @imaginary_part_from_incompatible_complex_type(%cplx: complex) { + // expected-error@+1 {{expects different type than prior uses: 'complex' vs 'complex'}} + std.im %cplx : complex + return +}