[mlir][emitc] Add 'emitc.while' and 'emitc.do' ops to the dialect #143008
base: main
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc
Author: Vlad Lazar (Vladislave0-0)
Changes
This MR adds:
Patch is 52.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143008.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d4aea52a0d485..2390c4ef24cbc 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1345,7 +1345,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}
def EmitC_YieldOp : EmitC_Op<"yield",
- [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
+ [Pure, Terminator, ParentOneOf<["DoOp", "ExpressionOp", "ForOp", "IfOp", "SwitchOp", "WhileOp"]>]> {
let summary = "Block termination operation";
let description = [{
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
@@ -1572,4 +1572,156 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
let hasVerifier = 1;
}
+def EmitC_WhileOp : EmitC_Op<"while",
+ [HasOnlyGraphRegion, RecursiveMemoryEffects, NoRegionArguments, OpAsmOpInterface, NoTerminator]> {
+ let summary = "While operation";
+ let description = [{
+ The `emitc.while` operation represents a C/C++ while loop construct that
+ repeatedly executes a body region as long as a condition region evaluates to
+ true. The operation has two regions:
+
+ 1. A condition region that must yield a boolean value (i1)
+ 2. A body region that contains the loop body
+
+ The condition region is evaluated before each iteration. If it yields true,
+ the body region is executed. The loop terminates when the condition yields
+ false. The condition region must contain exactly one block that terminates
+ with an `emitc.yield` operation producing an i1 value.
+
+ Example:
+
+ ```mlir
+ emitc.func @foo(%arg0 : !emitc.ptr<i32>) {
+ %var = "emitc.variable"() <{value = 0 : i32}> : () -> !emitc.lvalue<i32>
+ %0 = emitc.literal "10" : i32
+ %1 = emitc.literal "1" : i32
+
+ emitc.while {
+ %var_load = load %var : <i32>
+ %res = emitc.cmp le, %var_load, %0 : (i32, i32) -> i1
+ emitc.yield %res : i1
+ } do {
+ emitc.verbatim "printf(\"%d\", *{});" args %arg0 : !emitc.ptr<i32>
+ %var_load = load %var : <i32>
+ %tmp_add = add %var_load, %1 : (i32, i32) -> i32
+ "emitc.assign"(%var, %tmp_add) : (!emitc.lvalue<i32>, i32) -> ()
+ }
+
+ return
+ }
+ ```
+
+ ```c++
+ // Code emitted for the operation above.
+ void foo(int32_t* v1) {
+ int32_t v2 = 0;
+ while (v2 <= 10) {
+ printf("%d", *v1);
+ int32_t v3 = v2;
+ int32_t v4 = v3 + 1;
+ v2 = v4;
+ }
+ return;
+ }
+ ```
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+ let regions = (region MaxSizedRegion<1>:$conditionRegion,
+ MaxSizedRegion<1>:$bodyRegion);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ Operation *getRootOp();
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// EmitC ops in the body can omit their 'emitc.' prefix in the assembly.
+ static ::llvm::StringRef getDefaultDialect() {
+ return "emitc";
+ }
+ }];
+}
+
+def EmitC_DoOp : EmitC_Op<"do",
+ [RecursiveMemoryEffects, NoRegionArguments, OpAsmOpInterface, NoTerminator]> {
+ let summary = "Do-while operation";
+ let description = [{
+ The `emitc.do` operation represents a C/C++ do-while loop construct that
+ executes a body region first and then repeatedly executes it as long as a
+ condition region evaluates to true. The operation has two regions:
+
+ 1. A body region that contains the loop body
+ 2. A condition region that must yield a boolean value (i1)
+
+ Unlike a while loop, the body region is executed before the first evaluation
+ of the condition. The loop terminates when the condition yields false. The
+ condition region must contain exactly one block that terminates with an
+ `emitc.yield` operation producing an i1 value.
+
+ Example:
+
+ ```mlir
+ emitc.func @foo(%arg0 : !emitc.ptr<i32>) {
+ %var = "emitc.variable"() <{value = 0 : i32}> : () -> !emitc.lvalue<i32>
+ %0 = emitc.literal "10" : i32
+ %1 = emitc.literal "1" : i32
+
+ emitc.do {
+ emitc.verbatim "printf(\"%d\", *{});" args %arg0 : !emitc.ptr<i32>
+ %var_load = load %var : <i32>
+ %tmp_add = add %var_load, %1 : (i32, i32) -> i32
+ "emitc.assign"(%var, %tmp_add) : (!emitc.lvalue<i32>, i32) -> ()
+ } while {
+ %var_load = load %var : <i32>
+ %res = emitc.cmp le, %var_load, %0 : (i32, i32) -> i1
+ emitc.yield %res : i1
+ }
+
+ return
+ }
+ ```
+
+ ```c++
+ // Code emitted for the operation above.
+ void foo(int32_t* v1) {
+ int32_t v2 = 0;
+ do {
+ printf("%d", *v1);
+ int32_t v3 = v2;
+ int32_t v4 = v3 + 1;
+ v2 = v4;
+ } while (v2 <= 10);
+ return;
+ }
+ ```
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+ let regions = (region MaxSizedRegion<1>:$bodyRegion,
+ MaxSizedRegion<1>:$conditionRegion);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ Operation *getRootOp();
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// EmitC ops in the body can omit their 'emitc.' prefix in the assembly.
+ static ::llvm::StringRef getDefaultDialect() {
+ return "emitc";
+ }
+ }];
+}
+
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 345e8494194eb..2e454070cb185 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -333,11 +333,246 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}
+//==============================================================================
+
+// Lower scf::while to either emitc::while or emitc::do based on argument usage
+// patterns. Uses mutable variables to maintain loop state across iterations.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = whileOp.getLoc();
+ MLIRContext *context = loc.getContext();
+
+ // Create variable storage for loop-carried values to enable imperative
+ // updates while maintaining SSA semantics at conversion boundaries.
+ SmallVector<Value> variables;
+ if (failed(
+ createInitVariables(whileOp, rewriter, variables, loc, context))) {
+ return failure();
+ }
+
+ // Select lowering strategy based on condition argument usage:
+ // - emitc.while when condition args match region inputs (direct mapping);
+ // - emitc.do when condition args differ (requires state synchronization).
+ Region &beforeRegion = adaptor.getBefore();
+ Block &beforeBlock = beforeRegion.front();
+ auto condOp = cast<scf::ConditionOp>(beforeRegion.back().getTerminator());
+
+ bool isDoOp = !llvm::equal(beforeBlock.getArguments(), condOp.getArgs());
+
+ LogicalResult result =
+ isDoOp ? lowerDoWhile(whileOp, variables, context, rewriter, loc)
+ : lowerWhile(whileOp, variables, context, rewriter, loc);
+
+ if (failed(result))
+ return failure();
+
+ // Create an emitc::variable op for each result. These variables will be
+ // assigned to by emitc::assign ops within the loop body.
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
+ resultVariables))) {
+ return rewriter.notifyMatchFailure(whileOp,
+ "Failed to create result variables");
+ }
+
+ rewriter.setInsertionPointAfter(whileOp);
+
+ // Transfer final loop state to result variables and get final SSA results.
+ SmallVector<Value> finalResults =
+ finalizeLoopResults(resultVariables, variables, rewriter, loc);
+
+ rewriter.replaceOp(whileOp, finalResults);
+ return success();
+ }
+
+private:
+ // Initialize variables for loop-carried values to enable state updates
+ // across iterations without SSA argument passing.
+ static LogicalResult createInitVariables(WhileOp whileOp,
+ ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &outVars,
+ Location loc, MLIRContext *context) {
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
+
+ for (Value init : whileOp.getInits()) {
+ emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
+ loc, emitc::LValueType::get(init.getType()), noInit);
+ rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ outVars.push_back(var.getResult());
+ }
+
+ return success();
+ }
+
+ // Transition from SSA block arguments to variable-based state management by
+ // replacing argument uses with variable loads and cleaning up block
+ // interface.
+ void replaceBlockArgsWithVarLoads(Block *block, ArrayRef<Value> vars,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ rewriter.setInsertionPointToStart(block);
+
+ for (auto [arg, var] : llvm::zip(block->getArguments(), vars)) {
+ Type loadedType = cast<emitc::LValueType>(var.getType()).getValueType();
+ Value load = rewriter.create<emitc::LoadOp>(loc, loadedType, var);
+ arg.replaceAllUsesWith(load);
+ }
+
+ // Remove arguments after replacement to simplify block structure.
+ block->eraseArguments(0, block->getNumArguments());
+ }
+
+ // Convert SCF yield terminators to imperative assignments to update loop
+ // variables, maintaining loop semantics while transitioning to emitc model.
+ void processYieldTerminator(Operation *terminator, ArrayRef<Value> vars,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ auto yieldOp = cast<scf::YieldOp>(terminator);
+ SmallVector<Value> yields(yieldOp.getOperands());
+ rewriter.eraseOp(yieldOp);
+
+ rewriter.setInsertionPointToEnd(yieldOp->getBlock());
+ for (auto [var, val] : llvm::zip(vars, yields))
+ rewriter.create<emitc::AssignOp>(loc, var, val);
+ }
+
+ // Transfers final loop state from mutable variables to result variables,
+ // then returns the final SSA values to replace the original scf::while
+ // results.
+ static SmallVector<Value>
+ finalizeLoopResults(ArrayRef<Value> resultVariables,
+ ArrayRef<Value> loopVariables,
+ ConversionPatternRewriter &rewriter, Location loc) {
+ // Transfer final loop state to result variables to bridge imperative loop
+ // variables with SSA result expectations of the original op.
+ for (auto [resultVar, var] : llvm::zip(resultVariables, loopVariables)) {
+ Type loadedType = cast<emitc::LValueType>(var.getType()).getValueType();
+ Value load = rewriter.create<emitc::LoadOp>(loc, loadedType, var);
+ rewriter.create<emitc::AssignOp>(loc, resultVar, load);
+ }
+
+ // Replace op with loaded values to integrate with converted SSA graph.
+ SmallVector<Value> finalResults;
+ for (Value resultVar : resultVariables) {
+ Type loadedType =
+ cast<emitc::LValueType>(resultVar.getType()).getValueType();
+ finalResults.push_back(
+ rewriter.create<emitc::LoadOp>(loc, loadedType, resultVar));
+ }
+
+ return finalResults;
+ }
+
+ // Direct lowering to emitc.while when condition arguments match region
+ // inputs.
+ LogicalResult lowerWhile(WhileOp whileOp, ArrayRef<Value> vars,
+ MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ auto loweredWhile = rewriter.create<emitc::WhileOp>(loc);
+
+ // Lower before region to condition region.
+ rewriter.inlineRegionBefore(whileOp.getBefore(),
+ loweredWhile.getConditionRegion(),
+ loweredWhile.getConditionRegion().end());
+
+ Block *condBlock = &loweredWhile.getConditionRegion().front();
+ replaceBlockArgsWithVarLoads(condBlock, vars, rewriter, loc);
+
+ Operation *condTerminator =
+ loweredWhile.getConditionRegion().back().getTerminator();
+ auto condOp = cast<scf::ConditionOp>(condTerminator);
+ rewriter.setInsertionPoint(condOp);
+ Value condition = rewriter.getRemappedValue(condOp.getCondition());
+ rewriter.create<emitc::YieldOp>(condOp.getLoc(), condition);
+ rewriter.eraseOp(condOp);
+
+ // Lower after region to body region.
+ rewriter.inlineRegionBefore(whileOp.getAfter(),
+ loweredWhile.getBodyRegion(),
+ loweredWhile.getBodyRegion().end());
+
+ Block *bodyBlock = &loweredWhile.getBodyRegion().front();
+ replaceBlockArgsWithVarLoads(bodyBlock, vars, rewriter, loc);
+
+ // Convert scf.yield to variable assignments for state updates.
+ processYieldTerminator(bodyBlock->getTerminator(), vars, rewriter, loc);
+
+ return success();
+ }
+
+ // Lower to emitc.do when condition arguments differ from region inputs.
+ LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> vars,
+ MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ Type i1Type = IntegerType::get(context, 1);
+ auto globalCondition =
+ rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
+ Value conditionVal = globalCondition.getResult();
+
+ auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+
+ // Lower before region as body.
+ rewriter.inlineRegionBefore(whileOp.getBefore(), loweredDo.getBodyRegion(),
+ loweredDo.getBodyRegion().end());
+
+ Block *bodyBlock = &loweredDo.getBodyRegion().front();
+ replaceBlockArgsWithVarLoads(bodyBlock, vars, rewriter, loc);
+
+ // Convert scf.condition to condition variable assignment.
+ Operation *condTerminator =
+ loweredDo.getBodyRegion().back().getTerminator();
+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
+ rewriter.setInsertionPoint(condOp);
+ Value condition = rewriter.getRemappedValue(condOp.getCondition());
+ rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+
+ // Wrap body region in conditional to preserve scf semantics.
+ auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+
+ // Lower after region as then-block of conditional.
+ rewriter.inlineRegionBefore(whileOp.getAfter(), ifOp.getBodyRegion(),
+ ifOp.getBodyRegion().begin());
+
+ if (!ifOp.getBodyRegion().empty()) {
+ Block *ifBlock = &ifOp.getBodyRegion().front();
+
+ // Handle argument mapping from condition op to body region.
+ auto args = condOp.getArgs();
+ for (auto [arg, val] : llvm::zip(ifBlock->getArguments(), args))
+ arg.replaceAllUsesWith(rewriter.getRemappedValue(val));
+
+ ifBlock->eraseArguments(0, ifBlock->getNumArguments());
+
+ // Convert scf.yield to variable assignments for state updates.
+ processYieldTerminator(ifBlock->getTerminator(), vars, rewriter, loc);
+ rewriter.create<emitc::YieldOp>(loc);
+ }
+
+ rewriter.eraseOp(condOp);
+
+ // Create condition region that loads from the flag variable.
+ Block *condBlock = rewriter.createBlock(&loweredDo.getConditionRegion());
+ rewriter.setInsertionPointToStart(condBlock);
+ Value cond = rewriter.create<emitc::LoadOp>(loc, i1Type, conditionVal);
+ rewriter.create<emitc::YieldOp>(loc, cond);
+
+ return success();
+ }
+};
+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
+ patterns.add<WhileLowering>(typeConverter, patterns.getContext());
}
void SCFToEmitCPass::runOnOperation() {
@@ -354,7 +589,8 @@ void SCFToEmitCPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
+ target
+ .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 1709654b90138..70a7150b818ff 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -900,10 +900,12 @@ LogicalResult emitc::YieldOp::verify() {
Value result = getResult();
Operation *containingOp = getOperation()->getParentOp();
- if (result && containingOp->getNumResults() != 1)
+ if (result && containingOp->getNumResults() != 1 &&
+ !isa<WhileOp, DoOp>(containingOp))
return emitOpError() << "yields a value not returned by parent";
- if (!result && containingOp->getNumResults() != 0)
+ if (!result && containingOp->getNumResults() != 0 &&
+ !isa<WhileOp, DoOp>(containingOp))
return emitOpError() << "does not yield a value to be returned by parent";
return success();
@@ -1394,6 +1396,116 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
builder.getNamedAttr("id", builder.getStringAttr(id)));
}
+//===----------------------------------------------------------------------===//
+// Common functions for WhileOp and DoOp
+//===----------------------------------------------------------------------===//
+
+static Operation *getRootOpFromLoopCondition(Region &condRegion) {
+ auto yieldOp = cast<emitc::YieldOp>(condRegion.front().getTerminator());
+ return yieldOp.getResult().getDefiningOp();
+}
+
+static LogicalResult verifyLoopRegions(Operation &op, Region &condition,
+ Region &body) {
+ if (condition.empty())
+ return op.emitOpError("condition region cannot be empty");
+
+ Block &condBlock = condition.front();
+ for (Operation &inner : condBlock.without_terminator()) {
+ if (!inner.hasTrait<OpTrait::emitc::CExpression>())
+ return op.emitOpError(
+ "expected all operations in condition region must implement "
+ "CExpression trait, but ")
+ << inner.getName() << " does not";
+ }
+
+ auto condYield = dyn_cast<emitc::YieldOp>(condBlock.back());
+ if (!condYield)
+ return op.emitOpError(
+ "expected condition region to end with emitc.yield, but got ")
+ << condBlock.back().getName();
+
+ if (condYield.getNumOperands() != 1 ||
+ !condYield.getOperand(0).getType().isInteger(1))
+ return op.emitOpError("condition region must yield a single i1 value");
+
+ if (body.empty())
+ return op.emitOpError("body region cannot be empty");
+
+ Block &bodyBlock = body.front();
+ if (auto bodyYield = dyn_cast<emitc::YieldOp>(bodyBlock.back()))
+ if (bodyYield.getNumOperands() != 0)
+ return op.emitOpError(
+ "expected body region to return 0 values, but body returns ")
+ << bodyYield.getNumOperands();
+
+ return success();
+}
+
+static void printLoop(OpAsmPrinter &p, Operation *self, Region &f...
[truncated]
61c2975 to df7f06a
While working on the #143894 - This MR fixes it. Check out this MR first, please.
df7f06a to 3695d3d
@aniragil, gentle ping. I have resolved everything you requested.
@aniragil, gentle ping. I have resolved everything you requested.
Thanks, taking a look. Some initial comments inline.
auto exprOp = rewriter.create<emitc::ExpressionOp>(loc, TypeRange{i1Type});
Region &exprRegion = exprOp.getBodyRegion();

rewriter.inlineRegionBefore(whileOp.getBefore(), exprRegion,
The `before` block is not guaranteed to contain a valid C expression, e.g.:
func.func @double_use(%p : !emitc.ptr<i32>) -> i32 {
%init = emitc.literal "1.0" : i32
%var = emitc.literal "1.0" : i32
%exit = emitc.literal "10.0" : i32
%res = scf.while (%arg1 = %init) : (i32) -> i32 {
%used_twice = emitc.call @payload_with_side_effect(%arg1, %p) : (i32, !emitc.ptr<i32>) -> i32
%prod = emitc.add %used_twice, %used_twice : (i32, i32) -> i32
%sum = emitc.add %arg1, %prod : (i32, i32) -> i32
%condition = emitc.cmp lt, %sum, %exit : (i32, i32) -> i1
scf.condition(%condition) %arg1 : i32
} do {
^bb0(%arg2: i32):
%next_arg1 = emitc.call @payload_do(%arg2) : (i32) -> i32
scf.yield %next_arg1 : i32
}
return %res : i32
}
It may also include ops not related to the condition at all. The `emitc.expression` op can in principle be extended to support such sequences using the comma operator, but it currently doesn't (I'm also not sure it'd be very aesthetic)
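To illustrate why a value used twice in the `before` block cannot simply be folded into one C expression, here is a small sketch in plain C++. The function name `payload_with_side_effect` mirrors the MLIR example above, but its body, the call counter `g_calls`, and the two helper functions are assumptions made up for this illustration; none of it is code from the patch.

```cpp
#include <cstdint>

// Hypothetical stand-in for the callee in the MLIR example. It counts its
// calls so that the side effect is observable.
static int g_calls = 0;

int32_t payload_with_side_effect(int32_t x, int32_t *p) {
  ++g_calls;
  *p += x; // visible side effect through the pointer
  return x + 1;
}

// Faithful form: the call result is materialized once and then used twice,
// like %used_twice in the example.
bool condition_with_temporary(int32_t arg, int32_t *p, int32_t exit_value) {
  int32_t used_twice = payload_with_side_effect(arg, p);
  int32_t prod = used_twice + used_twice;
  return arg + prod < exit_value;
}

// Naive single-expression form: inlining the call at each use runs the side
// effect twice, which changes the observable behavior.
bool condition_inlined(int32_t arg, int32_t *p, int32_t exit_value) {
  return arg + (payload_with_side_effect(arg, p) +
                payload_with_side_effect(arg, p)) < exit_value;
}
```

Both helpers return the same boolean for the same inputs, but the inlined form calls the payload twice and updates `*p` twice, so it is not an equivalent condition.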
The before block is not guaranteed to contain a valid C expression, e.g.:
I wonder what would be the correct way for those things to get translated, given that it may contain operations that can translate to `CExpression` but, at the time of the transformation, are not `CExpression` yet, like if you have `arith`, etc. I thought that you just translate it anyway, and then it may fail later on due to incomplete translations. It's not like `emitc.while` won't fail now when translating to `cpp`, since it can not be emitted at all.
It may also include ops not related to the condition at all. The emitc.expression op can in principle be extended to support such sequences using the comma operator, but it currently doesn't (I'm also not sure it'd be very aesthetic)
This can be solved by routing it to the different loop lowering logic, `do/while`; this one can model pretty much everything, and `CExpression` is not really required for `Cond` in such a case. It's already present, and I remember telling @Vladislave0-0 about such cases, but I don't see it being here for `calls`, etc.
The `before` block is not guaranteed to contain a valid C expression, e.g.:
func.func @double_use(%p : !emitc.ptr<i32>) -> i32 {
%init = emitc.literal "1.0" : i32
%var = emitc.literal "1.0" : i32
%exit = emitc.literal "10.0" : i32
%res = scf.while (%arg1 = %init) : (i32) -> i32 {
%used_twice = emitc.call @payload_with_side_effect(%arg1, %p) : (i32, !emitc.ptr<i32>) -> i32
%prod = emitc.add %used_twice, %used_twice : (i32, i32) -> i32
%sum = emitc.add %arg1, %prod : (i32, i32) -> i32
%condition = emitc.cmp lt, %sum, %exit : (i32, i32) -> i1
scf.condition(%condition) %arg1 : i32
} do {
^bb0(%arg2: i32):
%next_arg1 = emitc.call @payload_do(%arg2) : (i32) -> i32
scf.yield %next_arg1 : i32
}
return %res : i32
}
It may also include ops not related to the condition at all. The `emitc.expression` op can in principle be extended to support such sequences using the comma operator, but it currently doesn't (I'm also not sure it'd be very aesthetic)
Thank you for noticing this bug. Given that we can not determine whether it'll really translate to a `CExpression` that won't have side effects, or translate at all, what do you think about always translating to a `do-while` style loop, as we do now when handling `scf.while`'s `do-while` style condition,
bool isDoOp = !llvm::equal(beforeBlock.getArguments(), condOp.getArgs());
LogicalResult result =
isDoOp ? lowerDoWhile(whileOp, variables, context, rewriter, loc)
: lowerWhile(whileOp, variables, context, rewriter, loc);
where such issues are eliminated due to translating into the `while` body, which is not required to be a `CExpression`? That way the translation of the `scf.while` itself will always work and it'll behave pretty much the same as it does right now.
As a downside, the new `emitc.while` will be "dead", since it's not part of the translation from the higher dialects, but tbf we can determine after the `form-expression` pass which `emitc.do` can be converted to a `while` operation, since we have a very specific pattern of `do { cond_res = cond; if (cond_res) while (cond_res) }`. And do that as a separate optimization.
I wonder what would be the correct way for those things to get translated, given that it may contain operations that can translate to CExpression, but at the time of the transformation, are not CExpression yet, like if you have arith, etc. I thought, that you just translate it anyway, and then it may fail later on due to incomplete translations. Like it's not like emitc.while won't fail now when translating to cpp, since it can not be emitted at all.
Sorry, I wasn't clear: I didn't mean ops of other dialects like `arith`, but `emitc` ops that do not form a single expression (at least not without using C's `comma` operator). For instance, in the following code:
while (true) {
a[i+1] = b[i*7] + 3;
foo(t[i] / 5);
int c = k[i + 11];
if (c)
break;
// do some more work with loads, stores, calls etc.
}
the "before" section includes computations unrelated to the condition, so it's not a "classic" while loop where the exit condition is checked before any computation. Since the `while` condition clause requires a single expression, putting all this code there requires turning it into a single expression. This can be done using the comma operator (which `emitc` currently doesn't support), i.e.
while (a[i+1] = b[i*7] + 3, foo(t[i] / 5), !k[i + 11]) {
// do some more work with loads, stores, calls etc.
}
But that's far less clear than the original structure IMO.
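The two loop shapes discussed here can be compared in a compact, compilable C++ sketch. The arrays `a` and `stop`, the counter `body_runs`, and both function names are hypothetical and chosen only for illustration. The loop exits when the loaded value is nonzero, so the comma form negates it.

```cpp
#include <vector>

// Form 1: the "before" work and the exit test both live in the loop body.
int run_with_break(std::vector<int> &a, const std::vector<int> &stop) {
  int i = 0;
  int body_runs = 0;
  while (true) {
    a[i] = a[i] + 3; // work unrelated to the condition
    int c = stop[i]; // value the exit test depends on
    if (c)
      break;
    ++body_runs; // stands in for "some more work"
    ++i;
  }
  return body_runs;
}

// Form 2: the same "before" work squeezed into the loop condition with the
// built-in comma operator. Only the last operand decides whether to continue.
int run_with_comma(std::vector<int> &a, const std::vector<int> &stop) {
  int i = 0;
  int body_runs = 0;
  while (a[i] = a[i] + 3, !stop[i]) {
    ++body_runs;
    ++i;
  }
  return body_runs;
}
```

Both functions perform the same updates and run the body the same number of times, provided `stop` contains a nonzero entry within bounds. The first form keeps the statement structure visible, which is the readability concern raised above.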
what do you think about always translating to do-while style loop ... where such issues are eliminated due to translating into while body, which is not required to be a CExpression? That way the translation of the scf.while itself will always work and it'll behave pretty much the same as it does right now. As a downside, the new emitc.while will be "dead", since it's not part of the translation from the higher dialects, but tbf we can determine after form-expression pass, which emitc.do can be converted to while operation, since we have a very specific pattern of do { cond_res = cond; if (cond_res) while (cond_res) }. And do that as a separate optimization.
Agreed, best to start with a simple and robust lowering. We shouldn't have dead ops, so I'd do this in stages - first a patch to introduce a single loop op to lower `scf.while` in a unified manner. I'm OK with starting either with `emitc.do` or `emitc.while` (with the condition variable initialized to `1`). In any case, lowering can create in the condition region a simple `emitc.expression` that only loads from the condition variable.
The form-expressions pass (if executed) should indeed fold the computation of the condition in the loop body.
A second patch can then introduce a new pass to optimize the loops in both directions, i.e. identify the `emitc.expression` that sets the condition variable and push it down `do {} while (/*HERE*/)` or push it up `while (/*HERE*/) {}` if possible, removing the condition variable. WDYT @kchibisov?
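As a rough picture of what such a follow-up optimization would achieve at the C level, the sketch below pairs the flag-based form with the loop obtained by moving the condition expression into the loop header. All function names and loop bodies are hypothetical; the sketch assumes the condition is a single side-effect-free expression, and it says nothing about how the pass itself would be implemented.

```cpp
// Unified lowering: the condition is computed into a flag inside the body.
int sum_with_flag(int n) {
  int i = 0, sum = 0;
  bool cond;
  do {
    cond = i < n; // expression that sets the condition variable
    if (cond) {
      sum += i;
      ++i;
    }
  } while (cond);
  return sum;
}

// "Pushed up": nothing precedes the condition, so it can move into the
// header of a while loop and the flag disappears.
int sum_pushed_up(int n) {
  int i = 0, sum = 0;
  while (i < n) {
    sum += i;
    ++i;
  }
  return sum;
}

// Flag form in which all work happens before the condition and the guarded
// part is empty.
int count_with_flag(int n) {
  int i = 0;
  bool cond;
  do {
    ++i;
    cond = i < n;
    if (cond) {
      // nothing to do after the condition
    }
  } while (cond);
  return i;
}

// "Pushed down": the condition moves into the do-while footer.
int count_pushed_down(int n) {
  int i = 0;
  do {
    ++i;
  } while (i < n);
  return i;
}
```

Each pair computes the same result for every `n`, which is the property such a pass would need to preserve when it removes the condition variable.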
removing the condition variable. WDYT @kchibisov?
Yeah, that's what I've suggested. Like first translate to something like
bool cond = true;
do {
cond = <SCF.COND>;
if (cond) {
<SCF.BODY>
}
} while (cond);
(well, with variations), as it does now IIRC. And then just add a pass (if desired) to optimize this into something more readable. This pattern is easily identified, so I wouldn't worry much about it.
The good thing is that this pattern will always translate.
A variation with `while` is also possible, but semantically it's a `do` loop, since you always have it started with `true`.
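The pattern and its `while` variation can be written out as a generic C++ sketch. The two callbacks stand in for the `scf.while` regions: `before` computes the condition and `after` is the loop body. The function names and the use of `std::function` are assumptions for illustration only and do not reflect how the conversion is implemented.

```cpp
#include <functional>

// do-style form: the flag is assigned in the body before it is ever read.
void lower_as_do(const std::function<bool()> &before,
                 const std::function<void()> &after) {
  bool cond = true;
  do {
    cond = before(); // corresponds to <SCF.COND>
    if (cond) {
      after(); // corresponds to <SCF.BODY>
    }
  } while (cond);
}

// while-style variation: the flag starts as true, so the first iteration
// always runs and the loop behaves like the do form.
void lower_as_while(const std::function<bool()> &before,
                    const std::function<void()> &after) {
  bool cond = true;
  while (cond) {
    cond = before();
    if (cond) {
      after();
    }
  }
}
```

In both forms `before` runs once more than `after`, matching the `scf.while` behavior where the condition region is evaluated a final time when the loop exits.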
3695d3d to 8a78dac
This MR adds:
- 'emitc::WhileOp' and 'emitc::DoOp' to the EmitC dialect
- Emission of the corresponding ops in the CppEmitter
- Conversion from the SCF dialect to the EmitC dialect for the ops
- Corresponding tests
Change the canonical structure of a conditional region:
- The condition region must contain exactly one block with:
  1. An `emitc.expression` operation producing an i1 value
  2. An `emitc.yield` passing through the expression result
- The body region must not yield any values
8a78dac to 26c160b
Deleted lowering from scf.while to emitc.while. Made minor code style changes.
26c160b to 8a69d05
This MR adds: