diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -491,6 +491,81 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// StrideOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_StrideOp : MemRef_Op<"stride", [NoSideEffect, MemRefsNormalizable]> {
+  let summary = "dimension stride operation";
+  let description = [{
+    The `stride` operation takes a memref and a dimension operand of type `index`.
+    It returns the stride of the requested dimension of the given memref.
+    If the dimension index is out of bounds the behavior is undefined.
+
+    The specified memref type is that of the first operand.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %x = memref.stride %A, %c0 : memref<4 x ? x f32>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$source,
+                       Index:$index);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = [{
+    attr-dict $source `,` $index `:` type($source)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$source, "int64_t":$index)>,
+    OpBuilder<(ins "Value":$source, "Value":$index)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Helper function to get the index as a simple integer if it is constant.
+    Optional<int64_t> getConstantIndex();
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// OffsetOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_OffsetOp : MemRef_Op<"offset", [NoSideEffect, MemRefsNormalizable]> {
+  let summary = "offset operation";
+  let description = [{
+    The `offset` operation takes a memref. It returns the offset of the given
+    memref.
+
+    The specified memref type is that of the operand.
+
+    Example:
+
+    ```mlir
+    %x = memref.offset %A : memref<4 x ? x f32>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = [{
+    attr-dict $source `:` type($source)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$source)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DmaStartOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -755,18 +755,19 @@
   return {};
 }
 
-LogicalResult DimOp::verify() {
+template <typename Op>
+static LogicalResult verifyOpDimIndex(Op op) {
   // Assume unknown index to be in range.
-  Optional<int64_t> index = getConstantIndex();
+  Optional<int64_t> index = op.getConstantIndex();
   if (!index)
     return success();
 
   // Check that constant index is not knowingly out of range.
-  auto type = getSource().getType();
-  if (auto memrefType = type.dyn_cast<MemRefType>()) {
+  auto type = op.getSource().getType();
+  if (auto memrefType = type.template dyn_cast<MemRefType>()) {
     if (*index >= memrefType.getRank())
-      return emitOpError("index is out of range");
-  } else if (type.isa<UnrankedMemRefType>()) {
+      return op.emitOpError("index is out of range");
+  } else if (type.template isa<UnrankedMemRefType>()) {
     // Assume index to be in range.
   } else {
     llvm_unreachable("expected operand with memref type");
@@ -774,6 +775,8 @@
   return success();
 }
 
+LogicalResult DimOp::verify() { return verifyOpDimIndex(*this); }
+
 /// Return a map with key being elements in `vals` and data being number of
 /// occurences of it.
Use std::map, since the `vals` here are strides and the
 /// dynamic stride value is the same as the tombstone value for
@@ -971,6 +974,65 @@
   results.add<DimOfMemRefReshape>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// StrideOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `type`, when it is a ranked memref, has a layout that can be
+/// expressed as strides and an offset. Shared by StrideOp and OffsetOp, whose
+/// source operand may also be an unranked memref; an unranked memref carries
+/// no static layout, so there is nothing to check for it.
+static LogicalResult verifyStridedLayout(Operation *op, Type type) {
+  auto memrefType = type.dyn_cast<MemRefType>();
+  if (!memrefType)
+    return success();
+
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return op->emitOpError("invalid memref layout");
+  return success();
+}
+
+void StrideOp::build(OpBuilder &builder, OperationState &result, Value source,
+                     int64_t index) {
+  auto loc = result.location;
+  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
+  build(builder, result, source, indexValue);
+}
+
+void StrideOp::build(OpBuilder &builder, OperationState &result, Value source,
+                     Value index) {
+  auto indexTy = builder.getIndexType();
+  build(builder, result, indexTy, source, index);
+}
+
+Optional<int64_t> StrideOp::getConstantIndex() {
+  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
+    return constantOp.getValue().cast<IntegerAttr>().getInt();
+  return {};
+}
+
+LogicalResult StrideOp::verify() {
+  // Layout errors take precedence over index errors.
+  if (failed(verifyStridedLayout(getOperation(), getSource().getType())))
+    return failure();
+  return verifyOpDimIndex(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// OffsetOp
+//===----------------------------------------------------------------------===//
+
+void OffsetOp::build(OpBuilder &builder, OperationState &result, Value source) {
+  auto indexTy = builder.getIndexType();
+  build(builder, result, indexTy, source);
+}
+
+LogicalResult OffsetOp::verify() {
+  return verifyStridedLayout(getOperation(), getSource().getType());
+}
+
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -994,3 +994,29 @@
   }
   return
 }
+
+// -----
+
+func.func @stride_invalid_index(%t : memref<4x4x?xf32>) {
+  %c5 = arith.constant 5 : index
+  // expected-error@+1 {{index is out of range}}
+  %1 = memref.stride %t, %c5 : memref<4x4x?xf32>
+  return
+}
+
+// -----
+
+func.func @stride_invalid_layout(%t : memref<10xf32, affine_map<(d0) -> (d0 mod 5)>>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{invalid memref layout}}
+  %1 = memref.stride %t, %c0 : memref<10xf32, affine_map<(d0) -> (d0 mod 5)>>
+  return
+}
+
+// -----
+
+func.func @offset_invalid_layout(%t : memref<10xf32, affine_map<(d0) -> (d0 mod 5)>>) {
+  // expected-error@+1 {{invalid memref layout}}
+  %1 = memref.offset %t : memref<10xf32, affine_map<(d0) -> (d0 mod 5)>>
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -336,3 +336,39 @@
   } { index_attr = 8 : index }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @stride
+// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4x?xf32>, %[[D:.*]]: index)
+func.func @stride(%t : memref<4x4x?xf32>, %d : index) {
+  // CHECK: %{{.*}} = memref.stride %[[MEMREF]], %[[D]] : memref<4x4x?xf32>
+  %0 = "memref.stride"(%t, %d) : (memref<4x4x?xf32>, index) -> index
+
+  // CHECK: %{{.*}} = memref.stride %[[MEMREF]], %[[D]] : memref<4x4x?xf32>
+  %1 = memref.stride %t, %d : memref<4x4x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @offset
+// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4x?xf32>)
+func.func @offset(%t : memref<4x4x?xf32>) {
+  // CHECK: %{{.*}} = memref.offset %[[MEMREF]] : memref<4x4x?xf32>
+  %0 = "memref.offset"(%t) : (memref<4x4x?xf32>) -> index
+
+  // CHECK: %{{.*}} = memref.offset %[[MEMREF]] : memref<4x4x?xf32>
+  %1 = memref.offset %t : memref<4x4x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @stride_offset_unranked
+// CHECK-SAME: (%[[MEMREF:.*]]: memref<*xf32>, %[[D:.*]]: index)
+func.func @stride_offset_unranked(%t : memref<*xf32>, %d : index) {
+  // CHECK: %{{.*}} = memref.stride %[[MEMREF]], %[[D]] : memref<*xf32>
+  %0 = memref.stride %t, %d : memref<*xf32>
+
+  // CHECK: %{{.*}} = memref.offset %[[MEMREF]] : memref<*xf32>
+  %1 = memref.offset %t : memref<*xf32>
+  return
+}