diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -870,4 +870,14 @@
   let verifier = "return ::verify(*this);";
 }
 
+def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
+                    Arguments<(ins LLVM_Type:$cond)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn =
+        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
+    builder.CreateCall(fn, {$cond});
+  }];
+}
+
 #endif // LLVMIR_OPS
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1639,4 +1639,21 @@
   }];
 }
 
+def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
+  let summary =
+      "assertion that gives alignment information to the input memref";
+  let description = [{
+    The `assume_alignment` operation takes a memref and an integer alignment
+    value, and internally annotates the buffer with the given alignment. If
+    the buffer isn't aligned to the given alignment, the behavior is undefined.
+
+    This operation doesn't affect the semantics of a correct program. It's for
+    optimization only, and the optimization is best-effort.
+  }];
+  let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
+  let results = (outs);
+
+  let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+}
+
 #endif // STANDARD_OPS
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -2501,6 +2501,45 @@
   }
 };
 
+struct AssumeAlignmentOpLowering
+    : public LLVMLegalizationPattern<AssumeAlignmentOp> {
+  using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<AssumeAlignmentOp> transformed(operands);
+    Value memref = transformed.memref();
+    unsigned alignment = cast<AssumeAlignmentOp>(op).alignment().getZExtValue();
+
+    MemRefDescriptor memRefDescriptor(memref);
+    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+
+    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
+    // the asserted memref.alignedPtr isn't used anywhere else, as the real
+    // users like load/store/views always re-extract memref.alignedPtr as they
+    // get lowered.
+    //
+    // This relies on LLVM's CSE optimization (potentially after SROA), since
+    // after CSE all memref.alignedPtr instances get de-duplicated into the same
+    // pointer SSA value.
+    Value zero =
+        createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0);
+    Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(),
+                                         alignment - 1);
+    Value ptrValue =
+        rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), ptr);
+    rewriter.create<LLVM::AssumeOp>(
+        op->getLoc(),
+        rewriter.create<LLVM::ICmpOp>(
+            op->getLoc(), LLVM::ICmpPredicate::eq,
+            rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
@@ -2612,6 +2651,7 @@
     bool useAlloca) {
   // clang-format off
   patterns.insert<
+      AssumeAlignmentOpLowering,
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2764,6 +2764,17 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AssumeAlignmentOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AssumeAlignmentOp op) {
+  unsigned alignment = op.alignment().getZExtValue();
+  if (!llvm::isPowerOf2_32(alignment))
+    return op.emitOpError("alignment must be power of 2");
+  return success();
+}
+
 namespace {
 
 /// Pattern to rewrite a subview op with constant size arguments.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -855,3 +855,18 @@
 // CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
 // CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
 }
+
+// -----
+
+// CHECK-LABEL: func @assume_alignment
+func @assume_alignment(%0 : memref<4x4xf16>) {
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK]] : !llvm.i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (!llvm.i1) -> ()
+  assume_alignment %0, 16 : memref<4x4xf16>
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -740,3 +740,11 @@
   tensor_store %1, %0 : memref<4x4xi32>
   return
 }
+
+// CHECK-LABEL: func @assume_alignment
+// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
+func @assume_alignment(%0: memref<4x4xf16>) {
+  // CHECK: assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
+  assume_alignment %0, 16 : memref<4x4xf16>
+  return
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1036,3 +1036,21 @@
   %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
   return
 }
+
+// -----
+
+// Alignment is not a power of 2.
+func @assume_alignment(%0: memref<4x4xf16>) {
+  // expected-error@+1 {{alignment must be power of 2}}
+  std.assume_alignment %0, 12 : memref<4x4xf16>
+  return
+}
+
+// -----
+
+// The alignment value is 0.
+func @assume_alignment(%0: memref<4x4xf16>) {
+  // expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: positive 32-bit integer attribute}}
+  std.assume_alignment %0, 0 : memref<4x4xf16>
+  return
+}
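
Note for users of this op from C++ passes: below is a minimal sketch of building std.assume_alignment programmatically. It assumes the default ODS-generated builder (a memref Value operand plus an I32 IntegerAttr for the alignment); the helper name annotateAlignment is hypothetical and is not part of this patch.

  // Sketch only, not part of the patch. Assumes the default ODS-generated
  // builder for AssumeAlignmentOp: (Value memref, IntegerAttr alignment).
  #include "mlir/Dialect/StandardOps/Ops.h"
  #include "mlir/IR/Builders.h"

  // Mark `buffer` (a memref-typed Value) as 16-byte aligned. Behavior is
  // undefined if the buffer is not actually 16-byte aligned at runtime.
  static void annotateAlignment(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::Value buffer) {
    b.create<mlir::AssumeAlignmentOp>(loc, buffer, b.getI32IntegerAttr(16));
  }

With the lowering above, each such op becomes the ptrtoint/and/icmp sequence feeding "llvm.intr.assume", which the LLVM IR exporter then emits as a call to the llvm.assume intrinsic, so LLVM's own passes can pick up the alignment fact on the aligned pointer.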