diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -327,6 +327,7 @@
   let assemblyFormat = [{
     $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
   }];
+  let hasFolder = 1;
 }
 
 def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
@@ -398,10 +399,14 @@
   let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
 }
 def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
-                                 LLVM_AnyNonAggregate, LLVM_AnyNonAggregate>;
+                                 LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {
+  let hasFolder = 1;
+}
 def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast",
                                        LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
-                                       LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
+                                       LLVM_ScalarOrVectorOf<LLVM_AnyPointer>> {
+  let hasFolder = 1;
+}
 def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr",
                                   LLVM_ScalarOrVectorOf<AnyInteger>,
                                   LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
 
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -2259,6 +2260,48 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Folder for LLVM::BitcastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
+  // bitcast(x : T0, T0) -> x
+  if (getArg().getType() == getType())
+    return getArg();
+  // bitcast(bitcast(x : T0, T1), T0) -> x
+  if (auto prev = getArg().getDefiningOp<BitcastOp>())
+    if (prev.getArg().getType() == getType())
+      return prev.getArg();
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// Folder for LLVM::AddrSpaceCastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
+  // addrspacecast(x : T0, T0) -> x
+  if (getArg().getType() == getType())
+    return getArg();
+  // addrspacecast(addrspacecast(x : T0, T1), T0) -> x
+  if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
+    if (prev.getArg().getType() == getType())
+      return prev.getArg();
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// Folder for LLVM::GEPOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
+  // gep %x:T, 0 -> %x (single constant-zero index, result type unchanged)
+  if (getBase().getType() == getType() && getIndices().size() == 1 &&
+      matchPattern(getIndices()[0], m_Zero()))
+    return getBase();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -38,3 +38,60 @@
 
   llvm.return %3 : f32
 }
+
+// -----
+// CHECK-LABEL: fold_bitcast
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-NEXT: llvm.return %[[a0]]
+llvm.func @fold_bitcast(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c = llvm.bitcast %x : !llvm.ptr<i8> to !llvm.ptr<i8>
+  llvm.return %c : !llvm.ptr<i8>
+}
+
+// CHECK-LABEL: fold_bitcast2
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-NEXT: llvm.return %[[a0]]
+llvm.func @fold_bitcast2(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c = llvm.bitcast %x : !llvm.ptr<i8> to !llvm.ptr<i32>
+  %d = llvm.bitcast %c : !llvm.ptr<i32> to !llvm.ptr<i8>
+  llvm.return %d : !llvm.ptr<i8>
+}
+
+// -----
+
+// CHECK-LABEL: fold_addrcast
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-NEXT: llvm.return %[[a0]]
+llvm.func @fold_addrcast(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c = llvm.addrspacecast %x : !llvm.ptr<i8> to !llvm.ptr<i8>
+  llvm.return %c : !llvm.ptr<i8>
+}
+
+// CHECK-LABEL: fold_addrcast2
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-NEXT: llvm.return %[[a0]]
+llvm.func @fold_addrcast2(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c = llvm.addrspacecast %x : !llvm.ptr<i8> to !llvm.ptr<i8, 5>
+  %d = llvm.addrspacecast %c : !llvm.ptr<i8, 5> to !llvm.ptr<i8>
+  llvm.return %d : !llvm.ptr<i8>
+}
+
+// -----
+
+// CHECK-LABEL: fold_gep
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-NEXT: llvm.return %[[a0]]
+llvm.func @fold_gep(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c0 = arith.constant 0 : i32
+  %c = llvm.getelementptr %x[%c0] : (!llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
+  llvm.return %c : !llvm.ptr<i8>
+}
+
+// CHECK-LABEL: no_fold_gep
+// CHECK: llvm.getelementptr
+llvm.func @no_fold_gep(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
+  %c1 = arith.constant 1 : i32
+  %c = llvm.getelementptr %x[%c1] : (!llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
+  llvm.return %c : !llvm.ptr<i8>
+}
+