diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -20,7 +20,6 @@ using std_alloc = ValueBuilder; using std_alloca = ValueBuilder; using std_call = OperationBuilder; -using std_create_complex = ValueBuilder; using std_constant = ValueBuilder; using std_constant_float = ValueBuilder; using std_constant_index = ValueBuilder; @@ -31,12 +30,10 @@ using std_dim = ValueBuilder; using std_fpext = ValueBuilder; using std_fptrunc = ValueBuilder; -using std_im = ValueBuilder; using std_index_cast = ValueBuilder; using std_muli = ValueBuilder; using std_mulf = ValueBuilder; using std_memref_cast = ValueBuilder; -using std_re = ValueBuilder; using std_ret = OperationBuilder; using std_rsqrt = ValueBuilder; using std_select = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -149,18 +149,6 @@ [DeclareOpInterfaceMethods])>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; -// Base class for standard arithmetic operations on complex numbers with a -// floating-point element type. -// These operations take two operands and return one result, all of which must -// be complex numbers of the same type. -// The assembly format is as follows -// -// cf %0, %1 : complex -// -class ComplexFloatArithmeticOp traits = []> : - ArithmeticOp, - Arguments<(ins Complex:$lhs, Complex:$rhs)>; - // Base class for memref allocating ops: alloca and alloc. // // %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)> @@ -263,26 +251,6 @@ }]; } -//===----------------------------------------------------------------------===// -// AddCFOp -//===----------------------------------------------------------------------===// - -def AddCFOp : ComplexFloatArithmeticOp<"addcf"> { - let summary = "complex number addition"; - let description = [{ - The `addcf` operation takes two complex number operands and returns their - sum, a single complex number. - All operands and result must be of the same type, a complex number with a - floating-point element type. - - Example: - - ```mlir - %a = addcf %b, %c : complex - ``` - }]; -} - //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// @@ -1178,40 +1146,6 @@ let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } -//===----------------------------------------------------------------------===// -// CreateComplexOp -//===----------------------------------------------------------------------===// - -def CreateComplexOp : Std_Op<"create_complex", - [NoSideEffect, - AllTypesMatch<["real", "imaginary"]>, - TypesMatchWith<"complex element type matches real operand type", - "complex", "real", - "$_self.cast().getElementType()">, - TypesMatchWith<"complex element type matches imaginary operand type", - "complex", "imaginary", - "$_self.cast().getElementType()">]> { - let summary = "creates a complex number"; - let description = [{ - The `create_complex` operation creates a complex number from two - floating-point operands, the real and the imaginary part. - - Example: - - ```mlir - %a = create_complex %b, %c : complex - ``` - }]; - - let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary); - let results = (outs Complex:$complex); - - let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)"; - - // `CreateComplexOp` is fully verified by its traits. - let verifier = ?; -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -1807,36 +1741,6 @@ let verifier = ?; } -//===----------------------------------------------------------------------===// -// ImOp -//===----------------------------------------------------------------------===// - -def ImOp : Std_Op<"im", - [NoSideEffect, - TypesMatchWith<"complex element type matches result type", - "complex", "imaginary", - "$_self.cast().getElementType()">]> { - let summary = "extracts the imaginary part of a complex number"; - let description = [{ - The `im` operation takes a single complex number as its operand and extracts - the imaginary part as a floating-point value. - - Example: - - ```mlir - %a = im %b : complex - ``` - }]; - - let arguments = (ins Complex:$complex); - let results = (outs AnyFloat:$imaginary); - - let assemblyFormat = "$complex attr-dict `:` type($complex)"; - - // `ImOp` is fully verified by its traits. - let verifier = ?; -} - //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -2414,36 +2318,6 @@ let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; } -//===----------------------------------------------------------------------===// -// ReOp -//===----------------------------------------------------------------------===// - -def ReOp : Std_Op<"re", - [NoSideEffect, - TypesMatchWith<"complex element type matches result type", - "complex", "real", - "$_self.cast().getElementType()">]> { - let summary = "extracts the real part of a complex number"; - let description = [{ - The `re` operation takes a single complex number as its operand and extracts - the real part as a floating-point value. - - Example: - - ```mlir - %a = re %b : complex - ``` - }]; - - let arguments = (ins Complex:$complex); - let results = (outs AnyFloat:$real); - - let assemblyFormat = "$complex attr-dict `:` type($complex)"; - - // `ReOp` is fully verified by its traits. - let verifier = ?; -} - //===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// @@ -2939,26 +2813,6 @@ }]; } -//===----------------------------------------------------------------------===// -// SubCFOp -//===----------------------------------------------------------------------===// - -def SubCFOp : ComplexFloatArithmeticOp<"subcf"> { - let summary = "complex number subtraction"; - let description = [{ - The `subcf` operation takes two complex number operands and returns their - difference, a single complex number. - All operands and result must be of the same type, a complex number with a - floating-point element type. - - Example: - - ```mlir - %a = subcf %b, %c : complex - ``` - }]; -} - //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1731,142 +1731,6 @@ } }; -// Lowerings for operations on complex numbers. - -struct CreateComplexOpLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CreateComplexOp complexOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - CreateComplexOp::Adaptor transformed(operands); - - // Pack real and imaginary part in a complex number struct. - auto loc = complexOp.getLoc(); - auto structType = typeConverter->convertType(complexOp.getType()); - auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); - complexStruct.setReal(rewriter, loc, transformed.real()); - complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); - - rewriter.replaceOp(complexOp, {complexStruct}); - return success(); - } -}; - -struct ReOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ReOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - ReOp::Adaptor transformed(operands); - - // Extract real part from the complex number struct. - ComplexStructBuilder complexStruct(transformed.complex()); - Value real = complexStruct.real(rewriter, op.getLoc()); - rewriter.replaceOp(op, real); - - return success(); - } -}; - -struct ImOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ImOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - ImOp::Adaptor transformed(operands); - - // Extract imaginary part from the complex number struct. - ComplexStructBuilder complexStruct(transformed.complex()); - Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); - rewriter.replaceOp(op, imaginary); - - return success(); - } -}; - -struct BinaryComplexOperands { - std::complex lhs, rhs; -}; - -template -BinaryComplexOperands -unpackBinaryComplexOperands(OpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op.getLoc(); - typename OpTy::Adaptor transformed(operands); - - // Extract real and imaginary values from operands. - BinaryComplexOperands unpacked; - ComplexStructBuilder lhs(transformed.lhs()); - unpacked.lhs.real(lhs.real(rewriter, loc)); - unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); - ComplexStructBuilder rhs(transformed.rhs()); - unpacked.rhs.real(rhs.real(rewriter, loc)); - unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); - - return unpacked; -} - -struct AddCFOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(AddCFOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - BinaryComplexOperands arg = - unpackBinaryComplexOperands(op, operands, rewriter); - - // Initialize complex number struct for result. - auto structType = typeConverter->convertType(op.getType()); - auto result = ComplexStructBuilder::undef(rewriter, loc, structType); - - // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); - Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); - Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); - result.setReal(rewriter, loc, real); - result.setImaginary(rewriter, loc, imag); - - rewriter.replaceOp(op, {result}); - return success(); - } -}; - -struct SubCFOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(SubCFOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - BinaryComplexOperands arg = - unpackBinaryComplexOperands(op, operands, rewriter); - - // Initialize complex number struct for result. - auto structType = typeConverter->convertType(op.getType()); - auto result = ComplexStructBuilder::undef(rewriter, loc, structType); - - // Emit IR to substract complex numbers. - auto fmf = LLVM::FMFAttr::get({}, op.getContext()); - Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); - Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); - result.setReal(rewriter, loc, real); - result.setImaginary(rewriter, loc, imag); - - rewriter.replaceOp(op, {result}); - return success(); - } -}; - struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -3910,7 +3774,6 @@ // clang-format off patterns.insert< AbsFOpLowering, - AddCFOpLowering, AddFOpLowering, AddIOpLowering, AllocaOpLowering, @@ -3927,7 +3790,6 @@ CopySignOpLowering, CosOpLowering, ConstantOpLowering, - CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, @@ -3941,7 +3803,6 @@ FPToSILowering, FPToUILowering, FPTruncLowering, - ImOpLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, @@ -3949,7 +3810,6 @@ OrOpLowering, PowFOpLowering, PrefetchOpLowering, - ReOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, @@ -3964,7 +3824,6 @@ SplatOpLowering, SplatNdOpLowering, SqrtOpLowering, - SubCFOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -65,66 +65,6 @@ return } -// CHECK-LABEL: llvm.func @complex_numbers() -// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32 -// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32 -// CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: llvm.return -func @complex_numbers() { - %real0 = constant 1.2 : f32 - %imag0 = constant 3.4 : f32 - %cplx2 = create_complex %real0, %imag0 : complex - %real1 = re %cplx2 : complex - %imag1 = im %cplx2 : complex - return -} - -// CHECK-LABEL: llvm.func @complex_addition() -// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> -// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : f64 -// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : f64 -// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> -// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> -func @complex_addition() { - %a_re = constant 1.2 : f64 - %a_im = constant 3.4 : f64 - %a = create_complex %a_re, %a_im : complex - %b_re = constant 5.6 : f64 - %b_im = constant 7.8 : f64 - %b = create_complex %b_re, %b_im : complex - %c = addcf %a, %b : complex - return -} - -// CHECK-LABEL: llvm.func @complex_substraction() -// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> -// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : f64 -// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : f64 -// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> -// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> -func @complex_substraction() { - %a_re = constant 1.2 : f64 - %a_im = constant 3.4 : f64 - %a = create_complex %a_re, %a_im : complex - %b_re = constant 5.6 : f64 - %b_im = constant 7.8 : f64 - %b = create_complex %b_re, %b_im : complex - %c = subcf %a, %b : complex - return -} - // CHECK-LABEL: func @simple_caller() { // CHECK-NEXT: llvm.call @simple_loop() : () -> () // CHECK-NEXT: llvm.return diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -89,24 +89,6 @@ // CHECK: %[[F7:.*]] = powf %[[F2]], %[[F2]] : f32 %f7 = powf %f2, %f2 : f32 - // CHECK: %[[C0:.*]] = create_complex %[[F2]], %[[F2]] : complex - %c0 = "std.create_complex"(%f2, %f2) : (f32, f32) -> complex - - // CHECK: %[[C1:.*]] = create_complex %[[F2]], %[[F2]] : complex - %c1 = create_complex %f2, %f2 : complex - - // CHECK: %[[REAL0:.*]] = re %[[CPLX0:.*]] : complex - %real0 = "std.re"(%c0) : (complex) -> f32 - - // CHECK: %[[REAL1:.*]] = re %[[CPLX0]] : complex - %real1 = re %c0 : complex - - // CHECK: %[[IMAG0:.*]] = im %[[CPLX0]] : complex - %imag0 = "std.im"(%c0) : (complex) -> f32 - - // CHECK: %[[IMAG1:.*]] = im %[[CPLX0]] : complex - %imag1 = im %c0 : complex - // CHECK: %c42_i32 = constant 42 : i32 %x = "std.constant"(){value = 42 : i32} : () -> i32 diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1173,50 +1173,6 @@ // ----- -func @complex_number_from_non_float_operands(%real: i32, %imag: i32) { - // expected-error@+1 {{'complex' must be complex type with floating-point elements, but got 'complex'}} - std.create_complex %real, %imag : complex - return -} - -// ----- - -// expected-note@+1 {{prior use here}} -func @complex_number_from_different_float_types(%real: f32, %imag: f64) { - // expected-error@+1 {{expects different type than prior uses: 'f32' vs 'f64'}} - std.create_complex %real, %imag : complex - return -} - -// ----- - -// expected-note@+1 {{prior use here}} -func @complex_number_from_incompatible_float_type(%real: f32, %imag: f32) { - // expected-error@+1 {{expects different type than prior uses: 'f64' vs 'f32'}} - std.create_complex %real, %imag : complex - return -} - -// ----- - -// expected-note@+1 {{prior use here}} -func @real_part_from_incompatible_complex_type(%cplx: complex) { - // expected-error@+1 {{expects different type than prior uses: 'complex' vs 'complex'}} - std.re %cplx : complex - return -} - -// ----- - -// expected-note@+1 {{prior use here}} -func @imaginary_part_from_incompatible_complex_type(%cplx: complex) { - // expected-error@+1 {{expects different type than prior uses: 'complex' vs 'complex'}} - std.re %cplx : complex - return -} - -// ----- - func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) { // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}} %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]