diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -22,12 +22,6 @@ class OwningRewritePatternList; -/// Creates an instance of the ExpandAtomic pass. -std::unique_ptr createExpandAtomicPass(); - -void populateExpandMemRefReshapePattern(OwningRewritePatternList &patterns, - MLIRContext *ctx); - void populateExpandTanhPattern(OwningRewritePatternList &patterns, MLIRContext *ctx); @@ -41,15 +35,16 @@ /// Creates an instance of func bufferization pass. std::unique_ptr createFuncBufferizePass(); -/// Creates an instance of the StdExpandDivs pass that legalizes Std -/// dialect Divs to be convertible to StaLLVMndard. For example, -/// `std.ceildivi_signed` get transformed to a number of std operations, -/// which can be lowered to LLVM. -std::unique_ptr createStdExpandDivsPass(); +/// Creates an instance of the StdExpand pass that legalizes Std +/// dialect ops to be convertible to LLVM. For example, +/// `std.ceildivi_signed` gets transformed to a number of std operations, +/// which can be lowered to LLVM; `memref_reshape` gets converted to +/// `memref_reinterpret_cast`. +std::unique_ptr createStdExpandOpsPass(); /// Collects a set of patterns to rewrite ops within the Std dialect. -void populateStdExpandDivsRewritePatterns(MLIRContext *context, - OwningRewritePatternList &patterns); +void populateStdExpandOpsPatterns(MLIRContext *context, + OwningRewritePatternList &patterns); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -11,20 +11,15 @@ include "mlir/Pass/PassBase.td" -def ExpandAtomic : FunctionPass<"expand-atomic"> { - let summary = "Expands AtomicRMWOp into GenericAtomicRMWOp."; - let constructor = "mlir::createExpandAtomicPass()"; -} - def StdBufferize : FunctionPass<"std-bufferize"> { let summary = "Bufferize the std dialect"; let constructor = "mlir::createStdBufferizePass()"; let dependentDialects = ["scf::SCFDialect"]; } -def StdExpandDivs : FunctionPass<"std-expand-divs"> { - let summary = "Legalize div std dialect operations to be convertible to LLVM."; - let constructor = "mlir::createStdExpandDivsPass()"; +def StdExpandOps : FunctionPass<"std-expand"> { + let summary = "Legalize std operations to be convertible to LLVM."; + let constructor = "mlir::createStdExpandOpsPass()"; } def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { diff --git a/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir --- a/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir +++ b/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand-divs -convert-vector-to-llvm | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand -convert-vector-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s @@ -79,4 +79,4 @@ // CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 ) // CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) // CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) -// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 ) \ No newline at end of file +// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 ) diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -13,6 +13,5 @@ LINK_LIBS PUBLIC MLIRLLVMIR - MLIRStandardOpsTransforms MLIRTransforms ) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -16,7 +16,6 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -4079,7 +4078,6 @@ populateStdToLLVMFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatterns(converter, patterns); - populateExpandMemRefReshapePattern(patterns, &converter.getContext()); } /// Convert a non-empty list of types to be returned from a function into a diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,11 +1,9 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms Bufferize.cpp - ExpandAtomic.cpp - ExpandMemRefReshape.cpp + ExpandOps.cpp ExpandTanh.cpp FuncBufferize.cpp FuncConversions.cpp - StdExpandDivs.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp +++ /dev/null @@ -1,95 +0,0 @@ -//===- ExpandAtomic.cpp - Code to perform expanding atomic ops ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements expansion of AtomicRMWOp into GenericAtomicRMWOp. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -using namespace mlir; - -namespace { - -/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with -/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to -/// `generic_atomic_rmw` with the expanded code. -/// -/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 -/// -/// will be lowered to -/// -/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> { -/// ^bb0(%current: f32): -/// %cmp = cmpf "ogt", %current, %fval : f32 -/// %new_value = select %cmp, %current, %fval : f32 -/// atomic_yield %new_value : f32 -/// } -struct AtomicRMWOpConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AtomicRMWOp op, - PatternRewriter &rewriter) const final { - CmpFPredicate predicate; - switch (op.kind()) { - case AtomicRMWKind::maxf: - predicate = CmpFPredicate::OGT; - break; - case AtomicRMWKind::minf: - predicate = CmpFPredicate::OLT; - break; - default: - return failure(); - } - - auto loc = op.getLoc(); - auto genericOp = - rewriter.create(loc, op.memref(), op.indices()); - OpBuilder bodyBuilder = - OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); - - Value lhs = genericOp.getCurrentValue(); - Value rhs = op.value(); - Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs); - Value select = bodyBuilder.create(loc, cmp, lhs, rhs); - bodyBuilder.create(loc, select); - - rewriter.replaceOp(op, genericOp.getResult()); - return success(); - } -}; - -struct ExpandAtomic : public ExpandAtomicBase { - void runOnFunction() override { - OwningRewritePatternList patterns; - patterns.insert(&getContext()); - - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addDynamicallyLegalOp([](AtomicRMWOp op) { - return op.kind() != AtomicRMWKind::maxf && - op.kind() != AtomicRMWKind::minf; - }); - if (failed(mlir::applyPartialConversion(getFunction(), target, - std::move(patterns)))) - signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr mlir::createExpandAtomicPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp +++ /dev/null @@ -1,70 +0,0 @@ -//===- ExpandMemRefReshape.cpp - Code to perform expanding memref_reshape -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements expansion of MemRefReshapeOp into -// MemRefReinterpretCastOp. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -using namespace mlir; - -namespace { - -/// Converts `memref_reshape` that has a target shape of a statically-known -/// size to `memref_reinterpret_cast`. -struct MemRefReshapeOpConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MemRefReshapeOp op, - PatternRewriter &rewriter) const final { - auto shapeType = op.shape().getType().cast(); - if (!shapeType.hasStaticShape()) - return failure(); - - int64_t rank = shapeType.cast().getDimSize(0); - SmallVector sizes, strides; - sizes.resize(rank); - strides.resize(rank); - - Location loc = op.getLoc(); - Value stride = rewriter.create(loc, 1); - for (int i = rank - 1; i >= 0; --i) { - Value index = rewriter.create(loc, i); - Value size = rewriter.create(loc, op.shape(), index); - if (!size.getType().isa()) - size = rewriter.create(loc, size, rewriter.getIndexType()); - sizes[i] = size; - strides[i] = stride; - if (i > 0) - stride = rewriter.create(loc, stride, size); - } - SmallVector staticSizes(rank, ShapedType::kDynamicSize); - SmallVector staticStrides(rank, - ShapedType::kDynamicStrideOrOffset); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.source(), /*staticOffset = */ 0, staticSizes, - staticStrides, /*offset=*/llvm::None, sizes, strides); - return success(); - } -}; - -} // namespace - -void mlir::populateExpandMemRefReshapePattern( - OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); -} diff --git a/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp rename from mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp rename to mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -21,6 +21,94 @@ namespace { +/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with +/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to +/// `generic_atomic_rmw` with the expanded code. +/// +/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 +/// +/// will be lowered to +/// +/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> { +/// ^bb0(%current: f32): +/// %cmp = cmpf "ogt", %current, %fval : f32 +/// %new_value = select %cmp, %current, %fval : f32 +/// atomic_yield %new_value : f32 +/// } +struct AtomicRMWOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AtomicRMWOp op, + PatternRewriter &rewriter) const final { + CmpFPredicate predicate; + switch (op.kind()) { + case AtomicRMWKind::maxf: + predicate = CmpFPredicate::OGT; + break; + case AtomicRMWKind::minf: + predicate = CmpFPredicate::OLT; + break; + default: + return failure(); + } + + auto loc = op.getLoc(); + auto genericOp = + rewriter.create(loc, op.memref(), op.indices()); + OpBuilder bodyBuilder = + OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); + + Value lhs = genericOp.getCurrentValue(); + Value rhs = op.value(); + Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs); + Value select = bodyBuilder.create(loc, cmp, lhs, rhs); + bodyBuilder.create(loc, select); + + rewriter.replaceOp(op, genericOp.getResult()); + return success(); + } +}; + +/// Converts `memref_reshape` that has a target shape of a statically-known +/// size to `memref_reinterpret_cast`. +struct MemRefReshapeOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MemRefReshapeOp op, + PatternRewriter &rewriter) const final { + auto shapeType = op.shape().getType().cast(); + if (!shapeType.hasStaticShape()) + return failure(); + + int64_t rank = shapeType.cast().getDimSize(0); + SmallVector sizes, strides; + sizes.resize(rank); + strides.resize(rank); + + Location loc = op.getLoc(); + Value stride = rewriter.create(loc, 1); + for (int i = rank - 1; i >= 0; --i) { + Value index = rewriter.create(loc, i); + Value size = rewriter.create(loc, op.shape(), index); + if (!size.getType().isa()) + size = rewriter.create(loc, size, rewriter.getIndexType()); + sizes[i] = size; + strides[i] = stride; + if (i > 0) + stride = rewriter.create(loc, stride, size); + } + SmallVector staticSizes(rank, ShapedType::kDynamicSize); + SmallVector staticStrides(rank, + ShapedType::kDynamicStrideOrOffset); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.source(), /*staticOffset = */ 0, staticSizes, + staticStrides, /*offset=*/llvm::None, sizes, strides); + return success(); + } +}; + /// Expands SignedCeilDivIOP (n, m) into /// 1) x = (m > 0) ? -1 : 1 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) @@ -121,35 +209,40 @@ } }; -} // namespace +struct StdExpandOpsPass : public StdExpandOpsBase { + void runOnFunction() override { + MLIRContext &ctx = getContext(); -namespace { -struct StdExpandDivs : public StdExpandDivsBase { - void runOnFunction() override; -}; -} // namespace + OwningRewritePatternList patterns; + populateStdExpandOpsPatterns(&ctx, patterns); -void StdExpandDivs::runOnFunction() { - MLIRContext &ctx = getContext(); + ConversionTarget target(getContext()); - OwningRewritePatternList patterns; - populateStdExpandDivsRewritePatterns(&ctx, patterns); + target.addLegalDialect(); + target.addDynamicallyLegalOp([](AtomicRMWOp op) { + return op.kind() != AtomicRMWKind::maxf && + op.kind() != AtomicRMWKind::minf; + }); + target.addDynamicallyLegalOp([](MemRefReshapeOp op) { + return !op.shape().getType().cast().hasStaticShape(); + }); + target.addIllegalOp(); + target.addIllegalOp(); + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) + signalPassFailure(); + } +}; - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); - if (failed( - applyPartialConversion(getFunction(), target, std::move(patterns)))) - signalPassFailure(); -} +} // namespace -void mlir::populateStdExpandDivsRewritePatterns( - MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert( +void mlir::populateStdExpandOpsPatterns(MLIRContext *context, + OwningRewritePatternList &patterns) { + patterns.insert( context); } -std::unique_ptr mlir::createStdExpandDivsPass() { - return std::make_unique(); +std::unique_ptr mlir::createStdExpandOpsPass() { + return std::make_unique(); } diff --git a/mlir/test/Dialect/Standard/expand-atomic.mlir b/mlir/test/Dialect/Standard/expand-atomic.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Standard/expand-atomic.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s - -// CHECK-LABEL: func @atomic_rmw_to_generic -// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) -func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32 - return %x : f32 -} -// CHECK: %0 = std.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { -// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): -// CHECK: [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32 -// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32 -// CHECK: atomic_yield [[SELECT]] : f32 -// CHECK: } -// CHECK: return %0 : f32 - -// ----- - -// CHECK-LABEL: func @atomic_rmw_no_conversion -func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32 - return %x : f32 -} -// CHECK-NOT: generic_atomic_rmw diff --git a/mlir/test/Dialect/Standard/expand-memref-reshape.mlir b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Standard/expand-memref-reshape.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: mlir-opt %s -test-expand-memref-reshape | FileCheck %s - -// CHECK-LABEL: func @memref_reshape( -func @memref_reshape(%input: memref<*xf32>, - %shape: memref<3xi32>) -> memref { - %result = memref_reshape %input(%shape) - : (memref<*xf32>, memref<3xi32>) -> memref - return %result : memref -} -// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>, -// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref { -// CHECK: [[C2:%.*]] = constant 2 : index -// CHECK: [[C1:%.*]] = constant 1 : index -// CHECK: [[C0:%.*]] = constant 0 : index -// CHECK: [[DIM_2:%.*]] = load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32> -// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index -// CHECK: [[DIM_1:%.*]] = load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32> -// CHECK: [[SIZE_1:%.*]] = index_cast [[DIM_1]] : i32 to index -// CHECK: [[STRIDE_0:%.*]] = muli [[SIZE_2]], [[SIZE_1]] : index -// CHECK: [[DIM_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32> -// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index - -// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]] -// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]], -// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[SIZE_2]], [[C1]]] -// CHECK-SAME: : memref<*xf32> to memref diff --git a/mlir/test/Dialect/Standard/std-expand-divs.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir rename from mlir/test/Dialect/Standard/std-expand-divs.mlir rename to mlir/test/Dialect/Standard/expand-ops.mlir --- a/mlir/test/Dialect/Standard/std-expand-divs.mlir +++ b/mlir/test/Dialect/Standard/expand-ops.mlir @@ -1,29 +1,27 @@ -// RUN: mlir-opt -std-expand-divs %s -split-input-file | FileCheck %s +// RUN: mlir-opt -std-expand %s -split-input-file | FileCheck %s -// Test floor divide with signed integer -// CHECK-LABEL: func @floordivi -// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 { -func @floordivi(%arg0: i32, %arg1: i32) -> (i32) { - %res = floordivi_signed %arg0, %arg1 : i32 - return %res : i32 -// CHECK: [[ONE:%.+]] = constant 1 : i32 -// CHECK: [[ZERO:%.+]] = constant 0 : i32 -// CHECK: [[MIN1:%.+]] = constant -1 : i32 -// CHECK: [[CMP1:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32 -// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : i32 -// CHECK: [[TRUE1:%.+]] = subi [[X]], [[ARG0]] : i32 -// CHECK: [[TRUE2:%.+]] = divi_signed [[TRUE1]], [[ARG1]] : i32 -// CHECK: [[TRUE3:%.+]] = subi [[MIN1]], [[TRUE2]] : i32 -// CHECK: [[FALSE:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32 -// CHECK: [[NNEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32 -// CHECK: [[NPOS:%.+]] = cmpi "sgt", [[ARG0]], [[ZERO]] : i32 -// CHECK: [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32 -// CHECK: [[MPOS:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32 -// CHECK: [[TERM1:%.+]] = and [[NNEG]], [[MPOS]] : i1 -// CHECK: [[TERM2:%.+]] = and [[NPOS]], [[MNEG]] : i1 -// CHECK: [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32 +// CHECK-LABEL: func @atomic_rmw_to_generic +// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) +func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { + %x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %x : f32 } +// CHECK: %0 = std.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32 +// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32 +// CHECK: atomic_yield [[SELECT]] : f32 +// CHECK: } +// CHECK: return %0 : f32 + +// ----- + +// CHECK-LABEL: func @atomic_rmw_no_conversion +func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { + %x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %x : f32 +} +// CHECK-NOT: generic_atomic_rmw // ----- @@ -54,3 +52,62 @@ // CHECK: [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1 // CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32 } + +// ----- + +// Test floor divide with signed integer +// CHECK-LABEL: func @floordivi +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 { +func @floordivi(%arg0: i32, %arg1: i32) -> (i32) { + %res = floordivi_signed %arg0, %arg1 : i32 + return %res : i32 +// CHECK: [[ONE:%.+]] = constant 1 : i32 +// CHECK: [[ZERO:%.+]] = constant 0 : i32 +// CHECK: [[MIN1:%.+]] = constant -1 : i32 +// CHECK: [[CMP1:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32 +// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : i32 +// CHECK: [[TRUE1:%.+]] = subi [[X]], [[ARG0]] : i32 +// CHECK: [[TRUE2:%.+]] = divi_signed [[TRUE1]], [[ARG1]] : i32 +// CHECK: [[TRUE3:%.+]] = subi [[MIN1]], [[TRUE2]] : i32 +// CHECK: [[FALSE:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32 +// CHECK: [[NNEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32 +// CHECK: [[NPOS:%.+]] = cmpi "sgt", [[ARG0]], [[ZERO]] : i32 +// CHECK: [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32 +// CHECK: [[MPOS:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32 +// CHECK: [[TERM1:%.+]] = and [[NNEG]], [[MPOS]] : i1 +// CHECK: [[TERM2:%.+]] = and [[NPOS]], [[MNEG]] : i1 +// CHECK: [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1 +// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32 +} + +// ----- + +// CHECK-LABEL: func @memref_reshape( +func @memref_reshape(%input: memref<*xf32>, + %shape: memref<3xi32>) -> memref { + %result = memref_reshape %input(%shape) + : (memref<*xf32>, memref<3xi32>) -> memref + return %result : memref +} +// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>, +// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref { + +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: [[C2:%.*]] = constant 2 : index +// CHECK: [[DIM_2:%.*]] = load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32> +// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index +// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[SIZE_2]] : index + +// CHECK: [[C1_:%.*]] = constant 1 : index +// CHECK: [[DIM_1:%.*]] = load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32> +// CHECK: [[SIZE_1:%.*]] = index_cast [[DIM_1]] : i32 to index +// CHECK: [[STRIDE_0:%.*]] = muli [[STRIDE_1]], [[SIZE_1]] : index + +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[DIM_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32> +// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index + +// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]] +// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]], +// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]] +// CHECK-SAME: : memref<*xf32> to memref diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestAffineLoopParametricTiling.cpp - TestExpandMemRefReshape.cpp TestExpandTanh.cpp TestCallGraph.cpp TestConstantFold.cpp diff --git a/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp deleted file mode 100644 --- a/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp +++ /dev/null @@ -1,39 +0,0 @@ -//===- TestExpandMemRefReshape.cpp - Test expansion of memref_reshape -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains test passes for expanding memref reshape. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -struct TestExpandMemRefReshapePass - : public PassWrapper { - void runOnFunction() override; -}; -} // end anonymous namespace - -void TestExpandMemRefReshapePass::runOnFunction() { - OwningRewritePatternList patterns; - populateExpandMemRefReshapePattern(patterns, &getContext()); - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); -} - -namespace mlir { -namespace test { -void registerTestExpandMemRefReshapePass() { - PassRegistration pass( - "test-expand-memref-reshape", "Test expanding memref reshape"); -} -} // namespace test -} // namespace mlir diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir --- a/mlir/test/mlir-cpu-runner/memref_reshape.mlir +++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \ +// RUN: mlir-opt %s -convert-scf-to-std -std-expand -convert-std-to-llvm \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -66,7 +66,6 @@ void registerTestDialect(DialectRegistry &); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); -void registerTestExpandMemRefReshapePass(); void registerTestExpandTanhPass(); void registerTestFinalizingBufferizePass(); void registerTestGpuParallelLoopMappingPass(); @@ -131,7 +130,6 @@ test::registerTestConvVectorization(); test::registerTestDominancePass(); test::registerTestDynamicPipelinePass(); - test::registerTestExpandMemRefReshapePass(); test::registerTestExpandTanhPass(); test::registerTestFinalizingBufferizePass(); test::registerTestGpuParallelLoopMappingPass();