diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -870,4 +870,14 @@
   let verifier = "return ::verify(*this);";
 }
 
+def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
+                    Arguments<(ins LLVM_Type:$cond)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn =
+        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
+    builder.CreateCall(fn, {$cond});
+  }];
+}
+
 #endif // LLVMIR_OPS
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1639,4 +1639,38 @@
   }];
 }
 
+def AssumeAlignOp : Std_Op<"assume_align"> {
+  let summary =
+      "assertion that gives alignment information to the input memref";
+  let description = [{
+    The assume alignment operation takes a memref and a integer of alignment
+    value, and internally annotates the buffer with the given alignment. If
+    the buffer isn't aligned to the given alignment, the behavior is undefined.
+
+    This operation doesn't affect the semantics of a correct program. It's for
+    optimization only, and the optimization is best-effort.
+  }];
+
+  let arguments = (ins AnyMemRef:$memref);
+  let results = (outs);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &result, Value memref, unsigned alignment", [{
+    result.addOperands(memref);
+    result.addAttribute(
+        getAlignmentAttrName(),
+        IntegerAttr::get(builder->getIndexType(), alignment));
+  }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getAlignmentAttrName() { return "align"; }
+
+    unsigned getAlignment() {
+      return getAttrOfType<IntegerAttr>(getAlignmentAttrName())
+          .getValue()
+          .getZExtValue();
+    }
+  }];
+}
+
 #endif // STANDARD_OPS
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -2507,6 +2507,46 @@
   }
 };
 
+struct AssumeAlignOpLowering : public LLVMLegalizationPattern<AssumeAlignOp> {
+  using LLVMLegalizationPattern<AssumeAlignOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<AssumeAlignOp> transformed(operands);
+    Value memref = transformed.memref();
+    unsigned alignment = cast<AssumeAlignOp>(op).getAlignment();
+
+    MemRefDescriptor memRefDescriptor(memref);
+    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+
+    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
+    // the asserted memref.alignedPtr isn't used anywhere else, as the real
+    // users like load/store/views always re-extract memref.alignedPtr as they
+    // get lowered.
+    //
+    // This relies on LLVM's CSE optimization (potentially after SROA), since
+    // after CSE all memref.alignedPtr instances get de-duplicated into the same
+    // pointer SSA value.
+    Value zero =
+        createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0);
+    Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(),
+                                         alignment - 1);
+    rewriter.create<LLVM::AssumeOp>(
+        op->getLoc(),
+        rewriter.create<LLVM::ICmpOp>(
+            op->getLoc(), LLVM::ICmpPredicate::eq,
+            rewriter.create<LLVM::AndOp>(op->getLoc(),
+                                         rewriter.create<LLVM::PtrToIntOp>(
+                                             op->getLoc(), getIndexType(), ptr),
+                                         mask),
+            zero));
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
@@ -2622,6 +2662,7 @@
       LoadOpLowering,
       MemRefCastOpLowering,
       StoreOpLowering,
+      AssumeAlignOpLowering,
       SubViewOpLowering,
       ViewOpLowering>(*converter.getDialect(), converter);
   patterns.insert<
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2764,6 +2764,36 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AssumeAlignOp
+//===----------------------------------------------------------------------===//
+
+// assume_align `memref` { align = `alignment` } : memref_type
+static ParseResult parseAssumeAlignOp(OpAsmParser &parser,
+                                      OperationState &result) {
+  OpAsmParser::OperandType memref;
+  MemRefType type;
+  return failure(parser.parseOperand(memref) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.parseColonType(type) ||
+                 parser.resolveOperand(memref, type, result.operands));
+}
+
+static void print(OpAsmPrinter &p, AssumeAlignOp op) {
+  p << op.getOperationName() << ' ' << op.getOperand() << ' ';
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getOperand().getType();
+}
+
+static LogicalResult verify(AssumeAlignOp op) {
+  if (!op.getAttrOfType<IntegerAttr>(op.getAlignmentAttrName()))
+    return op.emitOpError("missing integer attribute `align`");
+  unsigned align = op.getAlignment();
+  if ((align & (align - 1)) != 0 || align == 0)
+    return op.emitOpError("alignment must be power of 2 and positive");
+  return success();
+}
+
 namespace {
 /// Pattern to rewrite a subview op with constant size arguments.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -855,3 +855,18 @@
 // CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
 // CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
 }
+
+// -----
+
+// CHECK-LABEL: func @assume_align
+func @assume_align(%0 : memref<4x4xf16>) {
+  // CHECK:      %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : !llvm.i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (!llvm.i1) -> ()
+  std.assume_align %0 { align = 16 } : memref<4x4xf16>
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -740,3 +740,11 @@
   tensor_store %1, %0 : memref<4x4xi32>
   return
 }
+
+// CHECK-LABEL: func @assume_align
+// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
+func @assume_align(%0: memref<4x4xf16>) {
+  // CHECK: std.assume_align %[[MEMREF]] {align = 16 : i64} : memref<4x4xf16>
+  std.assume_align %0 {align = 16} : memref<4x4xf16>
+  return
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1036,3 +1036,21 @@
   %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
   return
 }
+
+// -----
+
+// alignment is not power of 2.
+func @assume_align(%0: memref<4x4xf16>) {
+  // expected-error@+1 {{alignment must be power of 2}}
+  std.assume_align %0 {align = 12} : memref<4x4xf16>
+  return
+}
+
+// -----
+
+// missing align value.
+func @assume_align(%0: memref<4x4xf16>) {
+  // expected-error@+1 {{missing integer attribute `align`}}
+  std.assume_align %0 : memref<4x4xf16>
+  return
+}