diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
 
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace llvm {
@@ -157,6 +158,11 @@
   // by LLVM.
   Type convertFloatType(FloatType type);
 
+  // Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
+  // `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
+  // `!llvm<"{ double, double }">`. Other element types are not supported.
+  Type convertComplexType(ComplexType type);
+
   /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type);
 
@@ -476,8 +482,8 @@
   }
 };
 
-/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops
-/// with one result. This supports higher-dimensional vector types.
+/// Basic lowering implementation to rewrite Ops with just one result to the
+/// LLVM Dialect. This supports higher-dimensional vector types.
template class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -18,6 +18,7 @@ using std_addf = ValueBuilder; using std_alloc = ValueBuilder; using std_call = OperationBuilder; +using std_create_complex = ValueBuilder; using std_constant = ValueBuilder; using std_constant_float = ValueBuilder; using std_constant_index = ValueBuilder; @@ -25,10 +26,12 @@ using std_dealloc = OperationBuilder; using std_dim = ValueBuilder; using std_extract_element = ValueBuilder; +using std_im = ValueBuilder; using std_index_cast = ValueBuilder; using std_muli = ValueBuilder; using std_mulf = ValueBuilder; using std_memref_cast = ValueBuilder; +using std_real = ValueBuilder; using std_ret = OperationBuilder; using std_select = ValueBuilder; using std_load = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -986,6 +986,39 @@ let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } +//===----------------------------------------------------------------------===// +// CreateComplexOp +//===----------------------------------------------------------------------===// + +def CreateComplexOp : Std_Op<"create_complex", [NoSideEffect]> { + let summary = "creates a complex number"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `std.create_complex` ssa-use `,` ssa-use `:` type + ``` + + The `create_complex` operation creates a complex number from two + floating-point operands, the real and the imaginary part. 
+ + Example: + + ```mlir + %a = create_complex %b, %c : f32, f32 -> complex + ``` + }]; + + let arguments = (ins AnyFloat:$real, AnyFloat:$imag); + let results = (outs Complex:$cplx); + + let assemblyFormat = [{ + $real `,` $imag attr-dict `:` type($real) `,` type($imag) `->` type($cplx) + }]; + + let verifier = [{ return ::verify(*this); }]; +} + //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -1475,6 +1508,39 @@ let hasFolder = 0; } +//===----------------------------------------------------------------------===// +// ImOp +//===----------------------------------------------------------------------===// + +def ImOp : Std_Op<"im", [NoSideEffect]> { + let summary = "extracts the imaginary part of a complex number"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `std.im` ssa-use `:` type + ``` + + The `im` operation extracts the imaginary part of a complex number as a + floating-point value. 
+ + Example: + + ```mlir + %a = im %b, %c : complex -> f32 + ``` + }]; + + let arguments = (ins Complex:$cplx); + let results = (outs AnyFloat:$imag); + + let assemblyFormat = [{ + $cplx attr-dict `:` type($cplx) `->` type($imag) + }]; + + let verifier = [{ return ::verify(*this); }]; +} + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -1855,6 +1921,39 @@ let assemblyFormat = "operands attr-dict `:` type(operands)"; } +//===----------------------------------------------------------------------===// +// ReOp +//===----------------------------------------------------------------------===// + +def ReOp : Std_Op<"re", [NoSideEffect]> { + let summary = "extracts the real part of a complex number"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `std.re` ssa-use `:` type + ``` + + The `re` operation extracts the real part of a complex number as a + floating-point value. + + Example: + + ```mlir + %a = re %b, %c : complex -> f32 + ``` + }]; + + let arguments = (ins Complex:$cplx); + let results = (outs AnyFloat:$real); + + let assemblyFormat = [{ + $cplx attr-dict `:` type($cplx) `->` type($real) + }]; + + let verifier = [{ return ::verify(*this); }]; +} + //===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -138,6 +138,7 @@ module->getDataLayout().getPointerSizeInBits(); // Register conversions for the standard types. 
+ addConversion([&](ComplexType type) { return convertComplexType(type); }); addConversion([&](FloatType type) { return convertFloatType(type); }); addConversion([&](FunctionType type) { return convertFunctionType(type); }); addConversion([&](IndexType type) { return convertIndexType(type); }); @@ -191,6 +192,14 @@ } } +Type LLVMTypeConverter::convertComplexType(ComplexType type) { + const FloatType elementType = type.getElementType().dyn_cast(); + const LLVM::LLVMType llvmElementType = + convertFloatType(elementType).dyn_cast(); + return LLVM::LLVMType::getStructTy(llvmDialect, + {llvmElementType, llvmElementType}); +} + // Except for signatures, MLIR function types are converted into LLVM // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { @@ -1284,6 +1293,75 @@ OneToOneConvertToLLVMPattern; using XOrOpLowering = VectorConvertToLLVMPattern; +// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and +// `ImOp`. + +const int REAL_IDX_IN_CPLX_NUM_STRUCT = 0; +const int IMAG_IDX_IN_CPLX_NUM_STRUCT = 1; + +struct CreateComplexOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto complexOp = cast(op); + OperandAdaptor transformed(operands); + + // Pack real and imaginary part into one structure. 
+ auto packedType = typeConverter.convertType(complexOp.getType()); + Value packed = rewriter.create(op->getLoc(), packedType); + packed = rewriter.create( + op->getLoc(), packedType, packed, transformed.real(), + rewriter.getI64ArrayAttr(REAL_IDX_IN_CPLX_NUM_STRUCT)); + packed = rewriter.create( + op->getLoc(), packedType, packed, transformed.imag(), + rewriter.getI64ArrayAttr(IMAG_IDX_IN_CPLX_NUM_STRUCT)); + + rewriter.replaceOp(op, packed); + return success(); + } +}; + +struct ReOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto reOp = cast(op); + OperandAdaptor transformed(operands); + + // Extract real part from packed structure. + auto reType = typeConverter.convertType(reOp.getType()); + rewriter.replaceOpWithNewOp( + op, reType, transformed.cplx(), + rewriter.getI64ArrayAttr(REAL_IDX_IN_CPLX_NUM_STRUCT)); + + return success(); + } +}; + +struct ImOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto imOp = cast(op); + OperandAdaptor transformed(operands); + + // Extract imaginary part from packed structure. + auto imType = typeConverter.convertType(imOp.getType()); + rewriter.replaceOpWithNewOp( + op, imType, transformed.cplx(), + rewriter.getI64ArrayAttr(IMAG_IDX_IN_CPLX_NUM_STRUCT)); + + return success(); + } +}; + // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. 
static bool isSupportedMemRefType(MemRefType type) { @@ -2894,6 +2972,7 @@ CopySignOpLowering, CosOpLowering, ConstLLVMOpLowering, + CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, @@ -2904,12 +2983,14 @@ Log2OpLowering, FPExtLowering, FPTruncLowering, + ImOpLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrOpLowering, PrefetchOpLowering, + ReOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1219,6 +1219,18 @@ ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } +//===----------------------------------------------------------------------===// +// CreateComplexOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(CreateComplexOp op) { + auto elementType = op.getType().dyn_cast().getElementType(); + if (op.real().getType() != elementType || op.imag().getType() != elementType) + return op.emitOpError( + "expected two operand types and complex element type to be the same"); + return success(); +} + //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// @@ -1627,6 +1639,19 @@ return false; } +//===----------------------------------------------------------------------===// +// ImOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ImOp op) { + const auto &complexElementType = + op.cplx().getType().dyn_cast().getElementType(); + if (complexElementType != op.imag().getType()) + return op.emitOpError( + "expected result type and complex element type to be the same"); + return success(); +} + 
//===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -1871,6 +1896,19 @@ return IntegerAttr(); } +//===----------------------------------------------------------------------===// +// ReOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ReOp op) { + const auto &complexElementType = + op.cplx().getType().dyn_cast().getElementType(); + if (complexElementType != op.real().getType()) + return op.emitOpError( + "expected result type and complex element type to be the same"); + return success(); +} + //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -65,6 +65,25 @@ return } +// CHECK-LABEL: llvm.func @complex_numbers() +// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : !llvm.float +// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : !llvm.float +// CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm<"{ float, float }"> +// CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm<"{ float, float }"> +// CHECK-NEXT: llvm.return +func @complex_numbers() { +^bb0: + %real0 = constant 1.2 : f32 + %imag0 = constant 3.4 : f32 + %cplx2 = create_complex %real0, 
%imag0 : f32, f32 -> complex + %real1 = re %cplx2 : complex -> f32 + %imag1 = im %cplx2 : complex -> f32 + return +} + // CHECK-LABEL: func @simple_caller() { // CHECK-NEXT: llvm.call @simple_loop() : () -> () // CHECK-NEXT: llvm.return @@ -367,6 +386,12 @@ func @get_i64() -> (i64) // CHECK-LABEL: func @get_f32() -> !llvm.float func @get_f32() -> (f32) +// CHECK-LABEL: func @get_c16() -> !llvm<"{ half, half }"> +func @get_c16() -> (complex) +// CHECK-LABEL: func @get_c32() -> !llvm<"{ float, float }"> +func @get_c32() -> (complex) +// CHECK-LABEL: func @get_c64() -> !llvm<"{ double, double }"> +func @get_c64() -> (complex) // CHECK-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> // CHECK32-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i32, [4 x i32], [4 x i32] }"> func @get_memref() -> (memref<42x?x10x?xf32>) diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -86,6 +86,24 @@ // CHECK: %13 = muli %4, %4 : i32 %i6 = muli %i2, %i2 : i32 + // CHECK: %[[C0:.*]] = create_complex %[[F2:.*]], %[[F2]] : f32, f32 -> complex + %c0 = "std.create_complex"(%f2, %f2) : (f32, f32) -> complex + + // CHECK: %[[C1:.*]] = create_complex %[[F2]], %[[F2]] : f32, f32 -> complex + %c1 = create_complex %f2, %f2 : f32, f32 -> complex + + // CHECK: %[[REAL0:.*]] = re %[[CPLX0:.*]] : complex -> f32 + %real0 = "std.re"(%c0) : (complex) -> f32 + + // CHECK: %[[REAL1:.*]] = re %[[CPLX0]] : complex -> f32 + %real1 = re %c0 : complex -> f32 + + // CHECK: %[[IMAG0:.*]] = im %[[CPLX0]] : complex -> f32 + %imag0 = "std.im"(%c0) : (complex) -> f32 + + // CHECK: %[[IMAG1:.*]] = im %[[CPLX0]] : complex -> f32 + %imag1 = im %c0 : complex -> f32 + // CHECK: %c42_i32 = constant 42 : i32 %x = "std.constant"(){value = 42 : i32} : () -> i32 diff --git a/mlir/test/mlir-cpu-runner/complex_numbers.mlir b/mlir/test/mlir-cpu-runner/complex_numbers.mlir new file mode 100644 --- 
/dev/null +++ b/mlir/test/mlir-cpu-runner/complex_numbers.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner -e create_and_extract_real -entry-point-result=f32 \ +// RUN: | FileCheck -check-prefix=CHECK_REAL %s + +// Create complex number and extracit real part. +func @create_and_extract_real() -> f32 { + %real = constant 1.2 : f32 + %imag = constant 2.3 : f32 + %cplx = create_complex %real, %imag : f32, f32 -> complex + %result = re %cplx : complex -> f32 + return %result : f32 +} +// CHECK_REAL: 1.200000e+00 + + + +// RUN: mlir-opt %s -convert-std-to-llvm \ +// RUN: | mlir-cpu-runner -e create_and_extract_imag -entry-point-result=f32 \ +// RUN: | FileCheck -check-prefix=CHECK_IMAG %s + +// Create complex number and extract imaginary part. +func @create_and_extract_imag() -> f32 { + %real = constant 1.2 : f32 + %imag = constant 2.3 : f32 + %cplx = create_complex %real, %imag : f32, f32 -> complex + %result = im %cplx : complex -> f32 + return %result : f32 +} +// CHECK_IMAG: 2.300000e+00